Bawil committed on
Commit
5199058
·
verified ·
1 Parent(s): 4266ba2

Upload 31 files

Browse files
Files changed (31) hide show
  1. models/for_WMH_Vent/class_weights/class_weights_fold0_standard_3class.json +27 -0
  2. models/for_WMH_Vent/class_weights/class_weights_fold1_standard_3class.json +27 -0
  3. models/for_WMH_Vent/class_weights/class_weights_fold2_standard_3class.json +27 -0
  4. models/for_WMH_Vent/class_weights/class_weights_fold3_standard_3class.json +27 -0
  5. models/for_WMH_Vent/data_splits/concat_fold_assignments.json +475 -0
  6. models/for_WMH_Vent/data_splits/fold_assignments.json +543 -0
  7. models/for_WMH_Vent/data_splits/for_assignment.py +234 -0
  8. models/for_WMH_Vent/data_splits/local_fold_assignments.json +421 -0
  9. models/for_WMH_Vent/data_splits/public_fold_assignments.json +102 -0
  10. models/for_WMH_Vent/download_models.txt +1 -0
  11. models/for_WMH_Vent/folds_results_zscore2_all/per_class_summary.csv +9 -0
  12. models/for_WMH_Vent/folds_results_zscore2_all/test_metrics_all_variants_folds.csv +27 -0
  13. models/for_WMH_Vent/folds_results_zscore2_all/training_info_all_variants_folds.csv +17 -0
  14. models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_test.csv +5 -0
  15. models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_training.csv +5 -0
  16. models/for_WMH_Vent/model_training_scripts/attn_unet_model.py +85 -0
  17. models/for_WMH_Vent/model_training_scripts/base_runner_all.py +23 -0
  18. models/for_WMH_Vent/model_training_scripts/dlv3_unet_model.py +198 -0
  19. models/for_WMH_Vent/model_training_scripts/dlv3_unet_model_GN.py +247 -0
  20. models/for_WMH_Vent/model_training_scripts/p4_compute_class_weights.py +353 -0
  21. models/for_WMH_Vent/model_training_scripts/p4_data_loader.py +912 -0
  22. models/for_WMH_Vent/model_training_scripts/p4_error_analysis.py +1033 -0
  23. models/for_WMH_Vent/model_training_scripts/p4_folds_results_aggregator.py +611 -0
  24. models/for_WMH_Vent/model_training_scripts/p4_inference.py +1146 -0
  25. models/for_WMH_Vent/model_training_scripts/p4_run_experiments_all.py +576 -0
  26. models/for_WMH_Vent/model_training_scripts/p4_unet_viz.py +640 -0
  27. models/for_WMH_Vent/model_training_scripts/p4_variant_all_net.py +1051 -0
  28. models/for_WMH_Vent/model_training_scripts/trans_unet_model.py +125 -0
  29. models/for_WMH_Vent/model_training_scripts/unet_model.py +87 -0
  30. models/for_WMH_Vent/model_training_scripts/utility_functions.py +96 -0
  31. models/for_WMH_Vent/results_fold_avg_var_1_zscore2/models/standard_3class/download_models.txt +1 -0
models/for_WMH_Vent/class_weights/class_weights_fold0_standard_3class.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fold_id": 0,
3
+ "class_scenario": "3class",
4
+ "preprocessing": "standard",
5
+ "num_classes": 3,
6
+ "total_pixels": 119144448,
7
+ "class_pixel_counts": [
8
+ 118420367,
9
+ 496384,
10
+ 227697
11
+ ],
12
+ "class_frequencies": [
13
+ 0.993922662682528,
14
+ 0.004166236936193619,
15
+ 0.0019111003812783622
16
+ ],
17
+ "class_weights": [
18
+ 0.003950922707703595,
19
+ 0.9423307646632635,
20
+ 2.0537183126290333
21
+ ],
22
+ "class_names": [
23
+ "Background",
24
+ "Ventricles",
25
+ "Abnormal WMH"
26
+ ]
27
+ }
models/for_WMH_Vent/class_weights/class_weights_fold1_standard_3class.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fold_id": 1,
3
+ "class_scenario": "3class",
4
+ "preprocessing": "standard",
5
+ "num_classes": 3,
6
+ "total_pixels": 119341056,
7
+ "class_pixel_counts": [
8
+ 118646442,
9
+ 470627,
10
+ 223987
11
+ ],
12
+ "class_frequencies": [
13
+ 0.994179588958891,
14
+ 0.003943546469037445,
15
+ 0.0018768645720714924
16
+ ],
17
+ "class_weights": [
18
+ 0.003834061426337229,
19
+ 0.96633402011123,
20
+ 2.029831918462433
21
+ ],
22
+ "class_names": [
23
+ "Background",
24
+ "Ventricles",
25
+ "Abnormal WMH"
26
+ ]
27
+ }
models/for_WMH_Vent/class_weights/class_weights_fold2_standard_3class.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fold_id": 2,
3
+ "class_scenario": "3class",
4
+ "preprocessing": "standard",
5
+ "num_classes": 3,
6
+ "total_pixels": 119472128,
7
+ "class_pixel_counts": [
8
+ 118787277,
9
+ 464952,
10
+ 219899
11
+ ],
12
+ "class_frequencies": [
13
+ 0.994267692293888,
14
+ 0.0038917194142553484,
15
+ 0.001840588291856658
16
+ ],
17
+ "class_weights": [
18
+ 0.0037673539050414257,
19
+ 0.9622481463361134,
20
+ 2.033984499758845
21
+ ],
22
+ "class_names": [
23
+ "Background",
24
+ "Ventricles",
25
+ "Abnormal WMH"
26
+ ]
27
+ }
models/for_WMH_Vent/class_weights/class_weights_fold3_standard_3class.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fold_id": 3,
3
+ "class_scenario": "3class",
4
+ "preprocessing": "standard",
5
+ "num_classes": 3,
6
+ "total_pixels": 119734272,
7
+ "class_pixel_counts": [
8
+ 118973104,
9
+ 509903,
10
+ 251265
11
+ ],
12
+ "class_frequencies": [
13
+ 0.9936428560738232,
14
+ 0.004258621959132971,
15
+ 0.0020985219670438216
16
+ ],
17
+ "class_weights": [
18
+ 0.004240031541573218,
19
+ 0.9890739908996539,
20
+ 2.006685977558773
21
+ ],
22
+ "class_names": [
23
+ "Background",
24
+ "Ventricles",
25
+ "Abnormal WMH"
26
+ ]
27
+ }
models/for_WMH_Vent/data_splits/concat_fold_assignments.json ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "datasets": [
4
+ "Local_SAI",
5
+ "Public_MSSEG"
6
+ ],
7
+ "total_patients": 115,
8
+ "test_patients": 13,
9
+ "trainval_patients": 102,
10
+ "local_split": "70/10/20",
11
+ "public_split": "60/20/20",
12
+ "n_folds": 4,
13
+ "random_seed": 42
14
+ },
15
+ "test_set": {
16
+ "patients": [
17
+ "110012",
18
+ "105549",
19
+ "109816",
20
+ "105074",
21
+ "106780",
22
+ "107680",
23
+ "108807",
24
+ "106063",
25
+ "114585",
26
+ "111489",
27
+ "c01p04",
28
+ "c07p05",
29
+ "c08p04"
30
+ ],
31
+ "n_patients": 13
32
+ },
33
+ "folds": {
34
+ "fold_0": {
35
+ "train_patients": [
36
+ "109395",
37
+ "115788",
38
+ "113845",
39
+ "114770",
40
+ "102313",
41
+ "104797",
42
+ "111189",
43
+ "105597",
44
+ "111140",
45
+ "106270",
46
+ "114836",
47
+ "108295",
48
+ "104518",
49
+ "110218",
50
+ "110784",
51
+ "101627",
52
+ "104280",
53
+ "107966",
54
+ "101228",
55
+ "104420",
56
+ "109944",
57
+ "114903",
58
+ "112765",
59
+ "106200",
60
+ "106506",
61
+ "106536",
62
+ "112055",
63
+ "104447",
64
+ "106976",
65
+ "105978",
66
+ "110543",
67
+ "114058",
68
+ "113394",
69
+ "107739",
70
+ "112657",
71
+ "111008",
72
+ "105911",
73
+ "111852",
74
+ "105465",
75
+ "114128",
76
+ "110280",
77
+ "112414",
78
+ "105302",
79
+ "107455",
80
+ "110327",
81
+ "114990",
82
+ "112730",
83
+ "104453",
84
+ "111691",
85
+ "114454",
86
+ "104474",
87
+ "104252",
88
+ "109654",
89
+ "104937",
90
+ "104871",
91
+ "107508",
92
+ "114525",
93
+ "115588",
94
+ "110540",
95
+ "109267",
96
+ "107539",
97
+ "108344",
98
+ "112659",
99
+ "112776",
100
+ "113046",
101
+ "107233",
102
+ "102035",
103
+ "106905",
104
+ "107997",
105
+ "112378",
106
+ "104520",
107
+ "106639",
108
+ "104670",
109
+ "104899",
110
+ "115628",
111
+ "108444",
112
+ "109923",
113
+ "110157",
114
+ "114304",
115
+ "114266",
116
+ "c08p03",
117
+ "c01p01",
118
+ "c08p02",
119
+ "c07p03",
120
+ "c07p04",
121
+ "c01p02",
122
+ "c07p01",
123
+ "c08p05",
124
+ "c07p02"
125
+ ],
126
+ "val_patients": [
127
+ "108726",
128
+ "105917",
129
+ "105755",
130
+ "109141",
131
+ "110497",
132
+ "112997",
133
+ "104810",
134
+ "108975",
135
+ "107130",
136
+ "107630",
137
+ "c01p05",
138
+ "c08p01",
139
+ "c01p03"
140
+ ],
141
+ "n_train": 89,
142
+ "n_val": 13
143
+ },
144
+ "fold_1": {
145
+ "train_patients": [
146
+ "108726",
147
+ "105917",
148
+ "105755",
149
+ "109141",
150
+ "110497",
151
+ "112997",
152
+ "104810",
153
+ "108975",
154
+ "107130",
155
+ "107630",
156
+ "114836",
157
+ "108295",
158
+ "104518",
159
+ "110218",
160
+ "110784",
161
+ "101627",
162
+ "104280",
163
+ "107966",
164
+ "101228",
165
+ "104420",
166
+ "109944",
167
+ "114903",
168
+ "112765",
169
+ "106200",
170
+ "106506",
171
+ "106536",
172
+ "112055",
173
+ "104447",
174
+ "106976",
175
+ "105978",
176
+ "110543",
177
+ "114058",
178
+ "113394",
179
+ "107739",
180
+ "112657",
181
+ "111008",
182
+ "105911",
183
+ "111852",
184
+ "105465",
185
+ "114128",
186
+ "110280",
187
+ "112414",
188
+ "105302",
189
+ "107455",
190
+ "110327",
191
+ "114990",
192
+ "112730",
193
+ "104453",
194
+ "111691",
195
+ "114454",
196
+ "104474",
197
+ "104252",
198
+ "109654",
199
+ "104937",
200
+ "104871",
201
+ "107508",
202
+ "114525",
203
+ "115588",
204
+ "110540",
205
+ "109267",
206
+ "107539",
207
+ "108344",
208
+ "112659",
209
+ "112776",
210
+ "113046",
211
+ "107233",
212
+ "102035",
213
+ "106905",
214
+ "107997",
215
+ "112378",
216
+ "104520",
217
+ "106639",
218
+ "104670",
219
+ "104899",
220
+ "115628",
221
+ "108444",
222
+ "109923",
223
+ "110157",
224
+ "114304",
225
+ "114266",
226
+ "c01p05",
227
+ "c08p01",
228
+ "c01p03",
229
+ "c07p03",
230
+ "c07p04",
231
+ "c01p02",
232
+ "c07p01",
233
+ "c08p05",
234
+ "c07p02"
235
+ ],
236
+ "val_patients": [
237
+ "109395",
238
+ "115788",
239
+ "113845",
240
+ "114770",
241
+ "102313",
242
+ "104797",
243
+ "111189",
244
+ "105597",
245
+ "111140",
246
+ "106270",
247
+ "c08p03",
248
+ "c01p01",
249
+ "c08p02"
250
+ ],
251
+ "n_train": 89,
252
+ "n_val": 13
253
+ },
254
+ "fold_2": {
255
+ "train_patients": [
256
+ "108726",
257
+ "105917",
258
+ "105755",
259
+ "109141",
260
+ "110497",
261
+ "112997",
262
+ "104810",
263
+ "108975",
264
+ "107130",
265
+ "107630",
266
+ "109395",
267
+ "115788",
268
+ "113845",
269
+ "114770",
270
+ "102313",
271
+ "104797",
272
+ "111189",
273
+ "105597",
274
+ "111140",
275
+ "106270",
276
+ "109944",
277
+ "114903",
278
+ "112765",
279
+ "106200",
280
+ "106506",
281
+ "106536",
282
+ "112055",
283
+ "104447",
284
+ "106976",
285
+ "105978",
286
+ "110543",
287
+ "114058",
288
+ "113394",
289
+ "107739",
290
+ "112657",
291
+ "111008",
292
+ "105911",
293
+ "111852",
294
+ "105465",
295
+ "114128",
296
+ "110280",
297
+ "112414",
298
+ "105302",
299
+ "107455",
300
+ "110327",
301
+ "114990",
302
+ "112730",
303
+ "104453",
304
+ "111691",
305
+ "114454",
306
+ "104474",
307
+ "104252",
308
+ "109654",
309
+ "104937",
310
+ "104871",
311
+ "107508",
312
+ "114525",
313
+ "115588",
314
+ "110540",
315
+ "109267",
316
+ "107539",
317
+ "108344",
318
+ "112659",
319
+ "112776",
320
+ "113046",
321
+ "107233",
322
+ "102035",
323
+ "106905",
324
+ "107997",
325
+ "112378",
326
+ "104520",
327
+ "106639",
328
+ "104670",
329
+ "104899",
330
+ "115628",
331
+ "108444",
332
+ "109923",
333
+ "110157",
334
+ "114304",
335
+ "114266",
336
+ "c01p05",
337
+ "c08p01",
338
+ "c01p03",
339
+ "c08p03",
340
+ "c01p01",
341
+ "c08p02",
342
+ "c07p01",
343
+ "c08p05",
344
+ "c07p02"
345
+ ],
346
+ "val_patients": [
347
+ "114836",
348
+ "108295",
349
+ "104518",
350
+ "110218",
351
+ "110784",
352
+ "101627",
353
+ "104280",
354
+ "107966",
355
+ "101228",
356
+ "104420",
357
+ "c07p03",
358
+ "c07p04",
359
+ "c01p02"
360
+ ],
361
+ "n_train": 89,
362
+ "n_val": 13
363
+ },
364
+ "fold_3": {
365
+ "train_patients": [
366
+ "108726",
367
+ "105917",
368
+ "105755",
369
+ "109141",
370
+ "110497",
371
+ "112997",
372
+ "104810",
373
+ "108975",
374
+ "107130",
375
+ "107630",
376
+ "109395",
377
+ "115788",
378
+ "113845",
379
+ "114770",
380
+ "102313",
381
+ "104797",
382
+ "111189",
383
+ "105597",
384
+ "111140",
385
+ "106270",
386
+ "114836",
387
+ "108295",
388
+ "104518",
389
+ "110218",
390
+ "110784",
391
+ "101627",
392
+ "104280",
393
+ "107966",
394
+ "101228",
395
+ "104420",
396
+ "110543",
397
+ "114058",
398
+ "113394",
399
+ "107739",
400
+ "112657",
401
+ "111008",
402
+ "105911",
403
+ "111852",
404
+ "105465",
405
+ "114128",
406
+ "110280",
407
+ "112414",
408
+ "105302",
409
+ "107455",
410
+ "110327",
411
+ "114990",
412
+ "112730",
413
+ "104453",
414
+ "111691",
415
+ "114454",
416
+ "104474",
417
+ "104252",
418
+ "109654",
419
+ "104937",
420
+ "104871",
421
+ "107508",
422
+ "114525",
423
+ "115588",
424
+ "110540",
425
+ "109267",
426
+ "107539",
427
+ "108344",
428
+ "112659",
429
+ "112776",
430
+ "113046",
431
+ "107233",
432
+ "102035",
433
+ "106905",
434
+ "107997",
435
+ "112378",
436
+ "104520",
437
+ "106639",
438
+ "104670",
439
+ "104899",
440
+ "115628",
441
+ "108444",
442
+ "109923",
443
+ "110157",
444
+ "114304",
445
+ "114266",
446
+ "c01p05",
447
+ "c08p01",
448
+ "c01p03",
449
+ "c08p03",
450
+ "c01p01",
451
+ "c08p02",
452
+ "c07p03",
453
+ "c07p04",
454
+ "c01p02"
455
+ ],
456
+ "val_patients": [
457
+ "109944",
458
+ "114903",
459
+ "112765",
460
+ "106200",
461
+ "106506",
462
+ "106536",
463
+ "112055",
464
+ "104447",
465
+ "106976",
466
+ "105978",
467
+ "c07p01",
468
+ "c08p05",
469
+ "c07p02"
470
+ ],
471
+ "n_train": 89,
472
+ "n_val": 13
473
+ }
474
+ }
475
+ }
models/for_WMH_Vent/data_splits/fold_assignments.json ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_patients": 115,
4
+ "test_patients": 23,
5
+ "trainval_patients": 92,
6
+ "n_folds": 5,
7
+ "random_seed": 42,
8
+ "datasets": [
9
+ "Local_SAI",
10
+ "Public_MSSEG"
11
+ ]
12
+ },
13
+ "test_set": {
14
+ "patients": [
15
+ "112776",
16
+ "104252",
17
+ "107539",
18
+ "111140",
19
+ "104518",
20
+ "107997",
21
+ "111189",
22
+ "110543",
23
+ "108344",
24
+ "104520",
25
+ "c01p01",
26
+ "107130",
27
+ "113394",
28
+ "c08p04",
29
+ "105074",
30
+ "101228",
31
+ "111691",
32
+ "105978",
33
+ "c07p01",
34
+ "109267",
35
+ "114836",
36
+ "c08p03",
37
+ "104670"
38
+ ],
39
+ "n_patients": 23
40
+ },
41
+ "folds": {
42
+ "fold_0": {
43
+ "train_patients": [
44
+ "102035",
45
+ "102313",
46
+ "104280",
47
+ "104447",
48
+ "104453",
49
+ "104474",
50
+ "104797",
51
+ "104810",
52
+ "104899",
53
+ "105302",
54
+ "105465",
55
+ "105549",
56
+ "105597",
57
+ "105755",
58
+ "105917",
59
+ "106063",
60
+ "106200",
61
+ "106506",
62
+ "106536",
63
+ "106639",
64
+ "106905",
65
+ "107233",
66
+ "107455",
67
+ "107508",
68
+ "107630",
69
+ "107680",
70
+ "107739",
71
+ "108295",
72
+ "108444",
73
+ "108726",
74
+ "109141",
75
+ "109395",
76
+ "109654",
77
+ "109923",
78
+ "109944",
79
+ "110012",
80
+ "110157",
81
+ "110280",
82
+ "110327",
83
+ "110497",
84
+ "110540",
85
+ "110784",
86
+ "111489",
87
+ "111852",
88
+ "112055",
89
+ "112378",
90
+ "112414",
91
+ "112657",
92
+ "112730",
93
+ "112765",
94
+ "112997",
95
+ "113046",
96
+ "114058",
97
+ "114128",
98
+ "114266",
99
+ "114304",
100
+ "114525",
101
+ "114585",
102
+ "114770",
103
+ "114903",
104
+ "114990",
105
+ "115588",
106
+ "115628",
107
+ "115788",
108
+ "c01p02",
109
+ "c01p03",
110
+ "c01p05",
111
+ "c07p02",
112
+ "c07p03",
113
+ "c07p04",
114
+ "c07p05",
115
+ "c08p02",
116
+ "c08p05"
117
+ ],
118
+ "val_patients": [
119
+ "101627",
120
+ "104420",
121
+ "104871",
122
+ "104937",
123
+ "105911",
124
+ "106270",
125
+ "106780",
126
+ "106976",
127
+ "107966",
128
+ "108807",
129
+ "108975",
130
+ "109816",
131
+ "110218",
132
+ "111008",
133
+ "112659",
134
+ "113845",
135
+ "114454",
136
+ "c01p04",
137
+ "c08p01"
138
+ ],
139
+ "n_train": 73,
140
+ "n_val": 19
141
+ },
142
+ "fold_1": {
143
+ "train_patients": [
144
+ "101627",
145
+ "102035",
146
+ "102313",
147
+ "104280",
148
+ "104420",
149
+ "104453",
150
+ "104474",
151
+ "104797",
152
+ "104871",
153
+ "104937",
154
+ "105302",
155
+ "105465",
156
+ "105755",
157
+ "105911",
158
+ "105917",
159
+ "106063",
160
+ "106200",
161
+ "106270",
162
+ "106506",
163
+ "106536",
164
+ "106639",
165
+ "106780",
166
+ "106905",
167
+ "106976",
168
+ "107233",
169
+ "107630",
170
+ "107966",
171
+ "108295",
172
+ "108444",
173
+ "108726",
174
+ "108807",
175
+ "108975",
176
+ "109141",
177
+ "109654",
178
+ "109816",
179
+ "109944",
180
+ "110157",
181
+ "110218",
182
+ "110280",
183
+ "110327",
184
+ "110497",
185
+ "110540",
186
+ "110784",
187
+ "111008",
188
+ "111489",
189
+ "111852",
190
+ "112055",
191
+ "112378",
192
+ "112414",
193
+ "112657",
194
+ "112659",
195
+ "112730",
196
+ "112765",
197
+ "113845",
198
+ "114304",
199
+ "114454",
200
+ "114525",
201
+ "114585",
202
+ "114770",
203
+ "114903",
204
+ "115628",
205
+ "115788",
206
+ "c01p02",
207
+ "c01p03",
208
+ "c01p04",
209
+ "c01p05",
210
+ "c07p02",
211
+ "c07p03",
212
+ "c07p04",
213
+ "c07p05",
214
+ "c08p01",
215
+ "c08p02",
216
+ "c08p05"
217
+ ],
218
+ "val_patients": [
219
+ "104447",
220
+ "104810",
221
+ "104899",
222
+ "105549",
223
+ "105597",
224
+ "107455",
225
+ "107508",
226
+ "107680",
227
+ "107739",
228
+ "109395",
229
+ "109923",
230
+ "110012",
231
+ "112997",
232
+ "113046",
233
+ "114058",
234
+ "114128",
235
+ "114266",
236
+ "114990",
237
+ "115588"
238
+ ],
239
+ "n_train": 73,
240
+ "n_val": 19
241
+ },
242
+ "fold_2": {
243
+ "train_patients": [
244
+ "101627",
245
+ "102035",
246
+ "102313",
247
+ "104420",
248
+ "104447",
249
+ "104810",
250
+ "104871",
251
+ "104899",
252
+ "104937",
253
+ "105465",
254
+ "105549",
255
+ "105597",
256
+ "105911",
257
+ "106063",
258
+ "106200",
259
+ "106270",
260
+ "106506",
261
+ "106780",
262
+ "106976",
263
+ "107233",
264
+ "107455",
265
+ "107508",
266
+ "107630",
267
+ "107680",
268
+ "107739",
269
+ "107966",
270
+ "108444",
271
+ "108807",
272
+ "108975",
273
+ "109141",
274
+ "109395",
275
+ "109654",
276
+ "109816",
277
+ "109923",
278
+ "109944",
279
+ "110012",
280
+ "110157",
281
+ "110218",
282
+ "110280",
283
+ "110327",
284
+ "110497",
285
+ "110784",
286
+ "111008",
287
+ "111489",
288
+ "111852",
289
+ "112055",
290
+ "112378",
291
+ "112414",
292
+ "112657",
293
+ "112659",
294
+ "112730",
295
+ "112765",
296
+ "112997",
297
+ "113046",
298
+ "113845",
299
+ "114058",
300
+ "114128",
301
+ "114266",
302
+ "114304",
303
+ "114454",
304
+ "114585",
305
+ "114770",
306
+ "114990",
307
+ "115588",
308
+ "115628",
309
+ "c01p02",
310
+ "c01p03",
311
+ "c01p04",
312
+ "c01p05",
313
+ "c07p02",
314
+ "c07p04",
315
+ "c08p01",
316
+ "c08p02",
317
+ "c08p05"
318
+ ],
319
+ "val_patients": [
320
+ "104280",
321
+ "104453",
322
+ "104474",
323
+ "104797",
324
+ "105302",
325
+ "105755",
326
+ "105917",
327
+ "106536",
328
+ "106639",
329
+ "106905",
330
+ "108295",
331
+ "108726",
332
+ "110540",
333
+ "114525",
334
+ "114903",
335
+ "115788",
336
+ "c07p03",
337
+ "c07p05"
338
+ ],
339
+ "n_train": 74,
340
+ "n_val": 18
341
+ },
342
+ "fold_3": {
343
+ "train_patients": [
344
+ "101627",
345
+ "102035",
346
+ "102313",
347
+ "104280",
348
+ "104420",
349
+ "104447",
350
+ "104453",
351
+ "104474",
352
+ "104797",
353
+ "104810",
354
+ "104871",
355
+ "104899",
356
+ "104937",
357
+ "105302",
358
+ "105465",
359
+ "105549",
360
+ "105597",
361
+ "105755",
362
+ "105911",
363
+ "105917",
364
+ "106063",
365
+ "106200",
366
+ "106270",
367
+ "106506",
368
+ "106536",
369
+ "106639",
370
+ "106780",
371
+ "106905",
372
+ "106976",
373
+ "107233",
374
+ "107455",
375
+ "107508",
376
+ "107680",
377
+ "107739",
378
+ "107966",
379
+ "108295",
380
+ "108444",
381
+ "108726",
382
+ "108807",
383
+ "108975",
384
+ "109395",
385
+ "109816",
386
+ "109923",
387
+ "110012",
388
+ "110218",
389
+ "110327",
390
+ "110497",
391
+ "110540",
392
+ "111008",
393
+ "112378",
394
+ "112414",
395
+ "112659",
396
+ "112730",
397
+ "112997",
398
+ "113046",
399
+ "113845",
400
+ "114058",
401
+ "114128",
402
+ "114266",
403
+ "114304",
404
+ "114454",
405
+ "114525",
406
+ "114585",
407
+ "114903",
408
+ "114990",
409
+ "115588",
410
+ "115628",
411
+ "115788",
412
+ "c01p03",
413
+ "c01p04",
414
+ "c07p02",
415
+ "c07p03",
416
+ "c07p05",
417
+ "c08p01"
418
+ ],
419
+ "val_patients": [
420
+ "107630",
421
+ "109141",
422
+ "109654",
423
+ "109944",
424
+ "110157",
425
+ "110280",
426
+ "110784",
427
+ "111489",
428
+ "111852",
429
+ "112055",
430
+ "112657",
431
+ "112765",
432
+ "114770",
433
+ "c01p02",
434
+ "c01p05",
435
+ "c07p04",
436
+ "c08p02",
437
+ "c08p05"
438
+ ],
439
+ "n_train": 74,
440
+ "n_val": 18
441
+ },
442
+ "fold_4": {
443
+ "train_patients": [
444
+ "101627",
445
+ "104280",
446
+ "104420",
447
+ "104447",
448
+ "104453",
449
+ "104474",
450
+ "104797",
451
+ "104810",
452
+ "104871",
453
+ "104899",
454
+ "104937",
455
+ "105302",
456
+ "105549",
457
+ "105597",
458
+ "105755",
459
+ "105911",
460
+ "105917",
461
+ "106270",
462
+ "106536",
463
+ "106639",
464
+ "106780",
465
+ "106905",
466
+ "106976",
467
+ "107455",
468
+ "107508",
469
+ "107630",
470
+ "107680",
471
+ "107739",
472
+ "107966",
473
+ "108295",
474
+ "108726",
475
+ "108807",
476
+ "108975",
477
+ "109141",
478
+ "109395",
479
+ "109654",
480
+ "109816",
481
+ "109923",
482
+ "109944",
483
+ "110012",
484
+ "110157",
485
+ "110218",
486
+ "110280",
487
+ "110540",
488
+ "110784",
489
+ "111008",
490
+ "111489",
491
+ "111852",
492
+ "112055",
493
+ "112657",
494
+ "112659",
495
+ "112765",
496
+ "112997",
497
+ "113046",
498
+ "113845",
499
+ "114058",
500
+ "114128",
501
+ "114266",
502
+ "114454",
503
+ "114525",
504
+ "114770",
505
+ "114903",
506
+ "114990",
507
+ "115588",
508
+ "115788",
509
+ "c01p02",
510
+ "c01p04",
511
+ "c01p05",
512
+ "c07p03",
513
+ "c07p04",
514
+ "c07p05",
515
+ "c08p01",
516
+ "c08p02",
517
+ "c08p05"
518
+ ],
519
+ "val_patients": [
520
+ "102035",
521
+ "102313",
522
+ "105465",
523
+ "106063",
524
+ "106200",
525
+ "106506",
526
+ "107233",
527
+ "108444",
528
+ "110327",
529
+ "110497",
530
+ "112378",
531
+ "112414",
532
+ "112730",
533
+ "114304",
534
+ "114585",
535
+ "115628",
536
+ "c01p03",
537
+ "c07p02"
538
+ ],
539
+ "n_train": 74,
540
+ "n_val": 18
541
+ }
542
+ }
543
+ }
models/for_WMH_Vent/data_splits/for_assignment.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from sklearn.model_selection import KFold
5
+
6
+ # ─────────────────────────────────────────────
7
+ # Patient IDs
8
+ # ─────────────────────────────────────────────
9
+ local_patients_id = [
10
+ '101228', '101627', '102035', '102313', '104252', '104280', '104420',
11
+ '104447', '104453', '104474', '104518', '104520', '104670', '104797',
12
+ '104810', '104871', '104899', '104937', '105074', '105302', '105465',
13
+ '105549', '105597', '105755', '105911', '105917', '105978', '106063',
14
+ '106200', '106270', '106506', '106536', '106639', '106780', '106905',
15
+ '106976', '107130', '107233', '107455', '107508', '107539', '107630',
16
+ '107680', '107739', '107966', '107997', '108295', '108344', '108444',
17
+ '108726', '108807', '108975', '109141', '109267', '109395', '109654',
18
+ '109816', '109923', '109944', '110012', '110157', '110218', '110280',
19
+ '110327', '110497', '110540', '110543', '110784', '111008', '111140',
20
+ '111189', '111489', '111691', '111852', '112055', '112378', '112414',
21
+ '112657', '112659', '112730', '112765', '112776', '112997', '113046',
22
+ '113394', '113845', '114058', '114128', '114266', '114304', '114454',
23
+ '114525', '114585', '114770', '114836', '114903', '114990', '115588',
24
+ '115628', '115788',
25
+ ]
26
+
27
+ public_patients_id = [
28
+ 'c01p01', 'c01p02', 'c01p03', 'c01p04', 'c01p05',
29
+ 'c07p01', 'c07p02', 'c07p03', 'c07p04', 'c07p05',
30
+ 'c08p01', 'c08p02', 'c08p03', 'c08p04', 'c08p05',
31
+ ]
32
+
33
+ RANDOM_SEED = 42
34
+ N_FOLDS = 4
35
+
36
+
37
+ # ─────────────────────────────────────────────────────────────────────────────
38
+ # make_folds_exact (LOCAL)
39
+ # Carves n_val_per_fold * n_folds patients as an exclusive val pool,
40
+ # then rotates the val window. Val sets are perfectly non-overlapping.
41
+ # ─────────────────────────────────────────────────────────────────────────────
42
def make_folds_exact(trainval, n_val_per_fold, n_folds, rng):
    """Build folds whose validation sets are disjoint and equally sized.

    The patients are shuffled, the first ``n_folds * n_val_per_fold`` of them
    become a dedicated validation pool, and the remainder is used for training
    in every fold.  Fold ``i`` validates on the ``i``-th consecutive slice of
    the pool and trains on the other slices plus the always-train remainder.

    Args:
        trainval: Sequence of patient IDs available for training/validation.
            It is copied into a new array, so the caller's sequence is not
            reordered.
        n_val_per_fold: Number of validation patients in each fold.
        n_folds: Number of folds to create.
        rng: ``numpy.random.Generator`` used to shuffle the patients.

    Returns:
        Dict mapping ``"fold_<i>"`` to a dict with the keys
        ``train_patients``, ``val_patients``, ``n_train`` and ``n_val``.

    Raises:
        AssertionError: If ``trainval`` holds fewer than
            ``n_folds * n_val_per_fold`` patients.
    """
    arr = np.array(trainval)  # copy, so the shuffle never touches the input
    rng.shuffle(arr)

    total_val_pool = n_folds * n_val_per_fold
    if total_val_pool > len(arr):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(
            f"Not enough trainval ({len(arr)}) for {n_folds} x {n_val_per_fold} val = {total_val_pool}"
        )
    val_pool = arr[:total_val_pool]    # each of these validates in exactly one fold
    train_base = arr[total_val_pool:]  # these train in every fold

    folds = {}
    for fold_idx in range(n_folds):
        start = fold_idx * n_val_per_fold
        stop = start + n_val_per_fold
        val_pts = val_pool[start:stop].tolist()
        # Train on the other folds' validation slices plus the always-train base.
        train_pts = np.concatenate(
            [val_pool[:start], val_pool[stop:], train_base]
        ).tolist()
        folds[f"fold_{fold_idx}"] = {
            "train_patients": train_pts,
            "val_patients": val_pts,
            "n_train": len(train_pts),
            "n_val": len(val_pts),
        }
    return folds
68
+
69
+
70
+ # ─────────────────────────────────────────────────────────────────────────────
71
+ # make_folds_kfold (PUBLIC)
72
+ # KFold keeps the val sets strictly non-overlapping even when the patient
73
+ # count is not a multiple of the fold count. With the current 12 trainval
74
+ # patients and N_FOLDS = 4 the val sizes are 3,3,3,3; with 5 folds they would
75
+ # be 3,3,2,2,2 (5 * 3 = 15 > 12, so exactly 3 per fold is impossible).
76
+ # ─────────────────────────────────────────────────────────────────────────────
77
def make_folds_kfold(trainval, n_folds, rng):
    """Split the patients into ``n_folds`` folds with scikit-learn's ``KFold``.

    The IDs are copied into an array and shuffled with ``rng``; ``KFold`` is
    then run with its own shuffling disabled, so the fold membership depends
    only on ``rng``.

    Args:
        trainval: Sequence of patient IDs available for training/validation.
        n_folds: Number of folds, passed to ``KFold`` as ``n_splits``.
        rng: ``numpy.random.Generator`` used to shuffle the patients.

    Returns:
        Dict mapping ``"fold_<i>"`` to a dict with the keys
        ``train_patients``, ``val_patients``, ``n_train`` and ``n_val``.
    """
    shuffled = np.array(trainval)
    rng.shuffle(shuffled)

    # The order was randomised above, so KFold itself must not shuffle again.
    splitter = KFold(n_splits=n_folds, shuffle=False)
    return {
        f"fold_{number}": {
            "train_patients": shuffled[train_rows].tolist(),
            "val_patients": shuffled[val_rows].tolist(),
            "n_train": len(train_rows),
            "n_val": len(val_rows),
        }
        for number, (train_rows, val_rows) in enumerate(splitter.split(shuffled))
    }
91
+
92
+
93
# ─────────────────────────────────────────────────────────────────────────────
# LOCAL -- 70 / 10 / 20
# 100 total -> test=20, val=10 per fold, train=70 per fold
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): the committed local_fold_assignments.json records 10 test and
# 90 trainval patients, which the 0.20 test fraction below does not reproduce
# -- confirm which of the two is the intended split before regenerating.
n_local = len(local_patients_id)               # 100 IDs in local_patients_id
n_local_test = round(n_local * 0.20)           # 20
n_local_val_per_fold = round(n_local * 0.10)   # 10

rng_local = np.random.default_rng(RANDOM_SEED)
local_arr = np.array(local_patients_id)
rng_local.shuffle(local_arr)

local_test = local_arr[:n_local_test].tolist()       # 20 held-out test patients
local_trainval = local_arr[n_local_test:].tolist()   # 80 left for train/val

local_folds = make_folds_exact(
    local_trainval,
    n_val_per_fold=n_local_val_per_fold,
    n_folds=N_FOLDS,
    rng=np.random.default_rng(RANDOM_SEED + 1),
)

# Each fold trains on every trainval patient except its own validation slice.
n_local_train_per_fold = len(local_trainval) - n_local_val_per_fold  # 70

local_split = {
    "metadata": {
        "dataset": "Local_SAI",
        "total_patients": n_local,
        "test_patients": n_local_test,
        "trainval_patients": len(local_trainval),
        "target_split": "70/10/20 (train/val/test)",
        # Derived from the real counts so the text cannot drift from the data
        # (it used to be a hard-coded string that said train=69).
        "exact_counts": (
            f"train={n_local_train_per_fold}, val={n_local_val_per_fold}, "
            f"test={n_local_test} per fold"
        ),
        "n_folds": N_FOLDS,
        "random_seed": RANDOM_SEED,
    },
    "test_set": {"patients": local_test, "n_patients": n_local_test},
    "folds": local_folds,
}
129
+
130
# ─────────────────────────────────────────────────────────────────────────────
# PUBLIC -- 60 / 20 / 20
# 15 total -> test=3 (center-balanced), trainval=12
# KFold(N_FOLDS = 4) on 12 -> val=3 and train=9 in every fold (non-overlapping)
# ─────────────────────────────────────────────────────────────────────────────
n_public = len(public_patients_id)   # 15

# Group the patients by center: the first three characters of the ID ("c01", ...).
centers = {}
for pid in public_patients_id:
    centers.setdefault(pid[:3], []).append(pid)

# Center-balanced test: 1 patient per center
public_test = []
public_trainval = []
for center, pids in sorted(centers.items()):
    arr = np.array(pids)
    # Seed from the bytes of the center name, not from hash(center): str hashes
    # are salted per interpreter run (PYTHONHASHSEED), so the old seed -- and
    # with it the chosen test patient -- changed from one run to the next.
    # NOTE(review): a rerun is now reproducible but is not expected to match
    # the test patients in the already committed JSON files.
    center_rng = np.random.default_rng([RANDOM_SEED, *center.encode("utf-8")])
    center_rng.shuffle(arr)
    public_test.append(str(arr[0]))        # 1 test per center -> 3 total (plain str, like .tolist())
    public_trainval += arr[1:].tolist()    # 4 trainval per center -> 12 total

public_folds = make_folds_kfold(
    public_trainval,
    n_folds=N_FOLDS,
    rng=np.random.default_rng(RANDOM_SEED + 2),
)

public_split = {
    "metadata": {
        "dataset": "Public_MSSEG",
        "total_patients": n_public,
        "test_patients": len(public_test),
        "trainval_patients": len(public_trainval),
        "target_split": "60/20/20 (train/val/test)",
        "n_folds": N_FOLDS,
        "random_seed": RANDOM_SEED,
        "center_balanced_test": True,
    },
    "test_set": {"patients": public_test, "n_patients": len(public_test)},
    "folds": public_folds,
}
171
+
172
# ─────────────────────────────────────────────────────────────────────────────
# CONCATENATED
# Fold i of the combined split is local fold i followed by public fold i;
# the combined test set is the local test set followed by the public one.
# ─────────────────────────────────────────────────────────────────────────────
concat_test = local_test + public_test
concat_folds = {
    fold_key: {
        "train_patients": local_fold["train_patients"] + public_folds[fold_key]["train_patients"],
        "val_patients": local_fold["val_patients"] + public_folds[fold_key]["val_patients"],
        "n_train": local_fold["n_train"] + public_folds[fold_key]["n_train"],
        "n_val": local_fold["n_val"] + public_folds[fold_key]["n_val"],
    }
    for fold_key, local_fold in local_folds.items()
}

concat_metadata = {
    "datasets": ["Local_SAI", "Public_MSSEG"],
    "total_patients": n_local + n_public,
    "test_patients": len(concat_test),
    "trainval_patients": len(local_trainval) + len(public_trainval),
    "local_split": "70/10/20",
    "public_split": "60/20/20",
    "n_folds": N_FOLDS,
    "random_seed": RANDOM_SEED,
}
concat_split = {
    "metadata": concat_metadata,
    "test_set": {"patients": concat_test, "n_patients": len(concat_test)},
    "folds": concat_folds,
}
201
+
202
# ─────────────────────────────────────────────────────────────────────────────
# Save
# Each split is written as indented JSON into the directory of this script.
# ─────────────────────────────────────────────────────────────────────────────
output_dir = os.path.dirname(os.path.abspath(__file__))

splits_by_filename = {
    "local_fold_assignments.json": local_split,
    "public_fold_assignments.json": public_split,
    "concat_fold_assignments.json": concat_split,
}
for name, data in splits_by_filename.items():
    path = os.path.join(output_dir, name)
    with open(path, "w") as f:
        json.dump(data, f, indent=2)
    print(f"Saved: {path}")
216
+
217
# ─────────────────────────────────────────────────────────────────────────────
# Sanity check
# Prints, per split and fold, the set sizes and how many patients are shared
# between train/val and between (train+val)/test, then reports any pair of
# folds whose validation sets intersect. Nothing is raised; it only reports.
# ─────────────────────────────────────────────────────────────────────────────
print("\n=== SANITY CHECK ===")
splits_to_check = (("LOCAL", local_split), ("PUBLIC", public_split), ("CONCAT", concat_split))
for label, split_data in splits_to_check:
    held_out = set(split_data["test_set"]["patients"])
    print(f"\n{label} (test={len(held_out)})")
    val_sets = []
    for fold_key, fold in split_data["folds"].items():
        train_ids = set(fold["train_patients"])
        val_ids = set(fold["val_patients"])
        val_sets.append(val_ids)
        n_train_val_shared = len(train_ids.intersection(val_ids))
        n_test_shared = len(train_ids.union(val_ids).intersection(held_out))
        print(f" {fold_key}: train={len(train_ids):3d}, val={len(val_ids):2d} | "
              f"train/val overlap={n_train_val_shared} | (train+val)/test overlap={n_test_shared}")
    # Collect every pair of folds (i < j) whose validation sets share a patient.
    clashes = []
    for i, first in enumerate(val_sets):
        for j in range(i + 1, len(val_sets)):
            if first & val_sets[j]:
                clashes.append(f"f{i}&f{j}")
    verdict = "FAIL: " + str(clashes) if clashes else "OK"
    print(f" Val sets unique across folds: {verdict}")
models/for_WMH_Vent/data_splits/local_fold_assignments.json ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "dataset": "Local_SAI",
4
+ "total_patients": 100,
5
+ "test_patients": 10,
6
+ "trainval_patients": 90,
7
+ "target_split": "80/10/10 (train/val/test)",
8
+ "exact_counts": "train=80, val=10, test=10 per fold",
9
+ "n_folds": 4,
10
+ "random_seed": 42
11
+ },
12
+ "test_set": {
13
+ "patients": [
14
+ "110012",
15
+ "105549",
16
+ "109816",
17
+ "105074",
18
+ "106780",
19
+ "107680",
20
+ "108807",
21
+ "106063",
22
+ "114585",
23
+ "111489"
24
+ ],
25
+ "n_patients": 10
26
+ },
27
+ "folds": {
28
+ "fold_0": {
29
+ "train_patients": [
30
+ "109395",
31
+ "115788",
32
+ "113845",
33
+ "114770",
34
+ "102313",
35
+ "104797",
36
+ "111189",
37
+ "105597",
38
+ "111140",
39
+ "106270",
40
+ "114836",
41
+ "108295",
42
+ "104518",
43
+ "110218",
44
+ "110784",
45
+ "101627",
46
+ "104280",
47
+ "107966",
48
+ "101228",
49
+ "104420",
50
+ "109944",
51
+ "114903",
52
+ "112765",
53
+ "106200",
54
+ "106506",
55
+ "106536",
56
+ "112055",
57
+ "104447",
58
+ "106976",
59
+ "105978",
60
+ "110543",
61
+ "114058",
62
+ "113394",
63
+ "107739",
64
+ "112657",
65
+ "111008",
66
+ "105911",
67
+ "111852",
68
+ "105465",
69
+ "114128",
70
+ "110280",
71
+ "112414",
72
+ "105302",
73
+ "107455",
74
+ "110327",
75
+ "114990",
76
+ "112730",
77
+ "104453",
78
+ "111691",
79
+ "114454",
80
+ "104474",
81
+ "104252",
82
+ "109654",
83
+ "104937",
84
+ "104871",
85
+ "107508",
86
+ "114525",
87
+ "115588",
88
+ "110540",
89
+ "109267",
90
+ "107539",
91
+ "108344",
92
+ "112659",
93
+ "112776",
94
+ "113046",
95
+ "107233",
96
+ "102035",
97
+ "106905",
98
+ "107997",
99
+ "112378",
100
+ "104520",
101
+ "106639",
102
+ "104670",
103
+ "104899",
104
+ "115628",
105
+ "108444",
106
+ "109923",
107
+ "110157",
108
+ "114304",
109
+ "114266"
110
+ ],
111
+ "val_patients": [
112
+ "108726",
113
+ "105917",
114
+ "105755",
115
+ "109141",
116
+ "110497",
117
+ "112997",
118
+ "104810",
119
+ "108975",
120
+ "107130",
121
+ "107630"
122
+ ],
123
+ "n_train": 80,
124
+ "n_val": 10
125
+ },
126
+ "fold_1": {
127
+ "train_patients": [
128
+ "108726",
129
+ "105917",
130
+ "105755",
131
+ "109141",
132
+ "110497",
133
+ "112997",
134
+ "104810",
135
+ "108975",
136
+ "107130",
137
+ "107630",
138
+ "114836",
139
+ "108295",
140
+ "104518",
141
+ "110218",
142
+ "110784",
143
+ "101627",
144
+ "104280",
145
+ "107966",
146
+ "101228",
147
+ "104420",
148
+ "109944",
149
+ "114903",
150
+ "112765",
151
+ "106200",
152
+ "106506",
153
+ "106536",
154
+ "112055",
155
+ "104447",
156
+ "106976",
157
+ "105978",
158
+ "110543",
159
+ "114058",
160
+ "113394",
161
+ "107739",
162
+ "112657",
163
+ "111008",
164
+ "105911",
165
+ "111852",
166
+ "105465",
167
+ "114128",
168
+ "110280",
169
+ "112414",
170
+ "105302",
171
+ "107455",
172
+ "110327",
173
+ "114990",
174
+ "112730",
175
+ "104453",
176
+ "111691",
177
+ "114454",
178
+ "104474",
179
+ "104252",
180
+ "109654",
181
+ "104937",
182
+ "104871",
183
+ "107508",
184
+ "114525",
185
+ "115588",
186
+ "110540",
187
+ "109267",
188
+ "107539",
189
+ "108344",
190
+ "112659",
191
+ "112776",
192
+ "113046",
193
+ "107233",
194
+ "102035",
195
+ "106905",
196
+ "107997",
197
+ "112378",
198
+ "104520",
199
+ "106639",
200
+ "104670",
201
+ "104899",
202
+ "115628",
203
+ "108444",
204
+ "109923",
205
+ "110157",
206
+ "114304",
207
+ "114266"
208
+ ],
209
+ "val_patients": [
210
+ "109395",
211
+ "115788",
212
+ "113845",
213
+ "114770",
214
+ "102313",
215
+ "104797",
216
+ "111189",
217
+ "105597",
218
+ "111140",
219
+ "106270"
220
+ ],
221
+ "n_train": 80,
222
+ "n_val": 10
223
+ },
224
+ "fold_2": {
225
+ "train_patients": [
226
+ "108726",
227
+ "105917",
228
+ "105755",
229
+ "109141",
230
+ "110497",
231
+ "112997",
232
+ "104810",
233
+ "108975",
234
+ "107130",
235
+ "107630",
236
+ "109395",
237
+ "115788",
238
+ "113845",
239
+ "114770",
240
+ "102313",
241
+ "104797",
242
+ "111189",
243
+ "105597",
244
+ "111140",
245
+ "106270",
246
+ "109944",
247
+ "114903",
248
+ "112765",
249
+ "106200",
250
+ "106506",
251
+ "106536",
252
+ "112055",
253
+ "104447",
254
+ "106976",
255
+ "105978",
256
+ "110543",
257
+ "114058",
258
+ "113394",
259
+ "107739",
260
+ "112657",
261
+ "111008",
262
+ "105911",
263
+ "111852",
264
+ "105465",
265
+ "114128",
266
+ "110280",
267
+ "112414",
268
+ "105302",
269
+ "107455",
270
+ "110327",
271
+ "114990",
272
+ "112730",
273
+ "104453",
274
+ "111691",
275
+ "114454",
276
+ "104474",
277
+ "104252",
278
+ "109654",
279
+ "104937",
280
+ "104871",
281
+ "107508",
282
+ "114525",
283
+ "115588",
284
+ "110540",
285
+ "109267",
286
+ "107539",
287
+ "108344",
288
+ "112659",
289
+ "112776",
290
+ "113046",
291
+ "107233",
292
+ "102035",
293
+ "106905",
294
+ "107997",
295
+ "112378",
296
+ "104520",
297
+ "106639",
298
+ "104670",
299
+ "104899",
300
+ "115628",
301
+ "108444",
302
+ "109923",
303
+ "110157",
304
+ "114304",
305
+ "114266"
306
+ ],
307
+ "val_patients": [
308
+ "114836",
309
+ "108295",
310
+ "104518",
311
+ "110218",
312
+ "110784",
313
+ "101627",
314
+ "104280",
315
+ "107966",
316
+ "101228",
317
+ "104420"
318
+ ],
319
+ "n_train": 80,
320
+ "n_val": 10
321
+ },
322
+ "fold_3": {
323
+ "train_patients": [
324
+ "108726",
325
+ "105917",
326
+ "105755",
327
+ "109141",
328
+ "110497",
329
+ "112997",
330
+ "104810",
331
+ "108975",
332
+ "107130",
333
+ "107630",
334
+ "109395",
335
+ "115788",
336
+ "113845",
337
+ "114770",
338
+ "102313",
339
+ "104797",
340
+ "111189",
341
+ "105597",
342
+ "111140",
343
+ "106270",
344
+ "114836",
345
+ "108295",
346
+ "104518",
347
+ "110218",
348
+ "110784",
349
+ "101627",
350
+ "104280",
351
+ "107966",
352
+ "101228",
353
+ "104420",
354
+ "110543",
355
+ "114058",
356
+ "113394",
357
+ "107739",
358
+ "112657",
359
+ "111008",
360
+ "105911",
361
+ "111852",
362
+ "105465",
363
+ "114128",
364
+ "110280",
365
+ "112414",
366
+ "105302",
367
+ "107455",
368
+ "110327",
369
+ "114990",
370
+ "112730",
371
+ "104453",
372
+ "111691",
373
+ "114454",
374
+ "104474",
375
+ "104252",
376
+ "109654",
377
+ "104937",
378
+ "104871",
379
+ "107508",
380
+ "114525",
381
+ "115588",
382
+ "110540",
383
+ "109267",
384
+ "107539",
385
+ "108344",
386
+ "112659",
387
+ "112776",
388
+ "113046",
389
+ "107233",
390
+ "102035",
391
+ "106905",
392
+ "107997",
393
+ "112378",
394
+ "104520",
395
+ "106639",
396
+ "104670",
397
+ "104899",
398
+ "115628",
399
+ "108444",
400
+ "109923",
401
+ "110157",
402
+ "114304",
403
+ "114266"
404
+ ],
405
+ "val_patients": [
406
+ "109944",
407
+ "114903",
408
+ "112765",
409
+ "106200",
410
+ "106506",
411
+ "106536",
412
+ "112055",
413
+ "104447",
414
+ "106976",
415
+ "105978"
416
+ ],
417
+ "n_train": 80,
418
+ "n_val": 10
419
+ }
420
+ }
421
+ }
models/for_WMH_Vent/data_splits/public_fold_assignments.json ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "dataset": "Public_MSSEG",
4
+ "total_patients": 15,
5
+ "test_patients": 3,
6
+ "trainval_patients": 12,
7
+ "target_split": "60/20/20 (train/val/test)",
8
+ "n_folds": 4,
9
+ "random_seed": 42,
10
+ "center_balanced_test": true
11
+ },
12
+ "test_set": {
13
+ "patients": [
14
+ "c01p04",
15
+ "c07p05",
16
+ "c08p04"
17
+ ],
18
+ "n_patients": 3
19
+ },
20
+ "folds": {
21
+ "fold_0": {
22
+ "train_patients": [
23
+ "c08p03",
24
+ "c01p01",
25
+ "c08p02",
26
+ "c07p03",
27
+ "c07p04",
28
+ "c01p02",
29
+ "c07p01",
30
+ "c08p05",
31
+ "c07p02"
32
+ ],
33
+ "val_patients": [
34
+ "c01p05",
35
+ "c08p01",
36
+ "c01p03"
37
+ ],
38
+ "n_train": 9,
39
+ "n_val": 3
40
+ },
41
+ "fold_1": {
42
+ "train_patients": [
43
+ "c01p05",
44
+ "c08p01",
45
+ "c01p03",
46
+ "c07p03",
47
+ "c07p04",
48
+ "c01p02",
49
+ "c07p01",
50
+ "c08p05",
51
+ "c07p02"
52
+ ],
53
+ "val_patients": [
54
+ "c08p03",
55
+ "c01p01",
56
+ "c08p02"
57
+ ],
58
+ "n_train": 9,
59
+ "n_val": 3
60
+ },
61
+ "fold_2": {
62
+ "train_patients": [
63
+ "c01p05",
64
+ "c08p01",
65
+ "c01p03",
66
+ "c08p03",
67
+ "c01p01",
68
+ "c08p02",
69
+ "c07p01",
70
+ "c08p05",
71
+ "c07p02"
72
+ ],
73
+ "val_patients": [
74
+ "c07p03",
75
+ "c07p04",
76
+ "c01p02"
77
+ ],
78
+ "n_train": 9,
79
+ "n_val": 3
80
+ },
81
+ "fold_3": {
82
+ "train_patients": [
83
+ "c01p05",
84
+ "c08p01",
85
+ "c01p03",
86
+ "c08p03",
87
+ "c01p01",
88
+ "c08p02",
89
+ "c07p03",
90
+ "c07p04",
91
+ "c01p02"
92
+ ],
93
+ "val_patients": [
94
+ "c07p01",
95
+ "c08p05",
96
+ "c07p02"
97
+ ],
98
+ "n_train": 9,
99
+ "n_val": 3
100
+ }
101
+ }
102
+ }
models/for_WMH_Vent/download_models.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Visit our Hugging Face page to download the trained models.
models/for_WMH_Vent/folds_results_zscore2_all/per_class_summary.csv ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Variant,Variant_Name,Class,Class_Name,DICE_mean,DICE_std,DICE_min,DICE_max,PRECISION_mean,PRECISION_std,PRECISION_min,PRECISION_max,RECALL_mean,RECALL_std,RECALL_min,RECALL_max,IOU_mean,IOU_std,IOU_min,IOU_max,SPECIFICITY_mean,SPECIFICITY_std,SPECIFICITY_min,SPECIFICITY_max,HD95_mean,HD95_std,HD95_min,HD95_max,LESION_SENSITIVITY_mean,LESION_SENSITIVITY_std,LESION_PRECISION_mean,LESION_PRECISION_std,LESION_F1_mean,LESION_F1_std,LESION_N_GT_LESIONS_total,LESION_N_PRED_LESIONS_total,LESION_TP_LESIONS_total,LESION_FN_LESIONS_total,LESION_FP_LESIONS_total
2
+ 1,unet,1,Ventricles,0.9296308495604303,0.003051861083997252,0.9245313971595007,0.9325869622190041,0.937810327296536,0.004534371323946414,0.9299792648109558,0.9408155762340407,0.9221807114485115,0.002258280011483868,0.9198278289806433,0.9258708828514376,0.86883963293893,0.005257697310767231,0.8600597389216186,0.8739402369187041,0.9992060262932462,5.3781685628696937e-05,0.9991132143255909,0.9992439887040112,1.0,0.0,1.0,1.0,,,,,,,0.0,0.0,0.0,0.0,0.0
3
+ 1,unet,2,Abnormal_WMH,0.8471261192911104,0.006988603634009174,0.8380055641483046,0.8562749203494235,0.8861666894324636,0.004785959852918547,0.8819829027399184,0.8938984076631564,0.8156915305742668,0.008049307835491817,0.8038219006844454,0.8241991234390766,0.7363711260717487,0.01045759631012397,0.7227128240374371,0.7500107261976053,0.9992840254178714,1.4410851399106235e-05,0.9992605553058871,0.9992976281722222,4.579276208116416,0.9906935564361317,3.105706418326917,5.669535448758443,,,,,,,,,,,
4
+ 2,attnunet,1,Ventricles,0.9104890513851166,0.024899999222747722,0.8675609526332009,0.9278078258646835,0.9203443411150141,0.02293220806285698,0.8806633355974736,0.9350793260117669,0.9019219497921485,0.026912289770452267,0.8562340091217527,0.923007569878129,0.83718265247343,0.040705456147877725,0.7670303812669205,0.8657288299651152,0.9989795287466985,0.000273307129456502,0.9985083429568933,0.9991706542188458,1.2282992876459566,0.3954259655345788,1.0,1.913197150583827,,,,,,,0.0,0.0,0.0,0.0,0.0
5
+ 2,attnunet,2,Abnormal_WMH,0.826975751920205,0.015579775036519973,0.8023896495263613,0.8453085695774089,0.8886304925692461,0.009863625617703263,0.8773743681724526,0.9034203955732797,0.779984526149581,0.01966711140506486,0.7519818345459007,0.8069007799264991,0.7066411846722295,0.022442058844092443,0.6715032946218304,0.7335430311874717,0.9993673813767945,5.564253319795054e-05,0.9993140913333589,0.9994591017880485,5.868210623237481,1.1565233310098124,4.299965705388233,7.125643309860791,,,,,,,,,,,
6
+ 3,dlv3unet,1,Ventricles,0.9005661992435416,0.0020867923289419102,0.8974816472833985,0.9032637548534808,0.8997698242284116,0.002619190013214462,0.8961697010761339,0.9031854303942591,0.9018365317555639,0.00304923631516394,0.896838192141597,0.9048159553177055,0.8198187029362412,0.0034773428576675017,0.8147050743404023,0.8243681794688564,0.9987641261362938,3.990839508953154e-05,0.9987224004496448,0.9988270370474228,1.0,0.0,1.0,1.0,,,,,,,0.0,0.0,0.0,0.0,0.0
7
+ 3,dlv3unet,2,Abnormal_WMH,0.7763168733853871,0.003073255677872925,0.772352120446274,0.7808450853012623,0.7932948495860329,0.01339398897949029,0.775105014208068,0.8127526621991921,0.7653741420470819,0.01188742959950416,0.7489376058870139,0.7803045906209611,0.6370210758311682,0.003999380279290716,0.6319568324618485,0.6429538443669356,0.9985668433084038,0.00013274507182798207,0.9983882554846853,0.9987628711423673,4.7126929962683395,0.5556289444958085,4.095494923513843,5.423612659365294,,,,,,,,,,,
8
+ 4,transunet,1,Ventricles,0.9246872887842248,0.004597522753464204,0.917144392619005,0.9284374594079503,0.9320059959760637,0.011631626135529186,0.9158405109177434,0.9481783808085702,0.9184641125298365,0.004922383784531681,0.9100900076594562,0.9224872999070102,0.8603159951545595,0.007862485177144832,0.8474580455251116,0.8667238350341302,0.9991215386213639,0.00017857472387549319,0.9988581545201223,0.999358864513995,1.0,0.0,1.0,1.0,,,,,,,0.0,0.0,0.0,0.0,0.0
9
+ 4,transunet,2,Abnormal_WMH,0.8322919090444327,0.010816310171427137,0.81389122085058,0.8417282856836877,0.9035192038694635,0.003183241810633453,0.8989558813413554,0.9074719810187692,0.7761566255927599,0.015082099936625985,0.7515534359141638,0.7926254513986619,0.7142798166712054,0.01573238533985691,0.6875437726337009,0.7281151773663771,0.9994577459658872,1.5563988190790074e-05,0.9994339623025651,0.9994755539827085,5.929181221900818,1.9288286807668098,4.026591793193768,8.744998558832915,,,,,,,,,,,
models/for_WMH_Vent/folds_results_zscore2_all/test_metrics_all_variants_folds.csv ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Variant,Variant_Name,Fold,Test_Samples,DICE_class_1,DICE_class_2,DICE_mean,PRECISION_class_1,PRECISION_class_2,PRECISION_mean,RECALL_class_1,RECALL_class_2,RECALL_mean,IOU_class_1,IOU_class_2,IOU_mean,SPECIFICITY_class_1,SPECIFICITY_class_2,SPECIFICITY_mean,HD95_class_1,HD95_class_2,HD95_mean,LESION_LESION_SENSITIVITY_class_0,LESION_LESION_PRECISION_class_0,LESION_LESION_F1_class_0,LESION_N_GT_LESIONS_class_0,LESION_N_PRED_LESIONS_class_0,LESION_TP_LESIONS_class_0,LESION_FN_LESIONS_class_0,LESION_FP_LESIONS_class_0,LESION_LESION_SENSITIVITY_class_1,LESION_LESION_PRECISION_class_1,LESION_LESION_F1_class_1,LESION_N_GT_LESIONS_class_1,LESION_N_PRED_LESIONS_class_1,LESION_TP_LESIONS_class_1,LESION_FN_LESIONS_class_1,LESION_FP_LESIONS_class_1,LESION_LESION_SENSITIVITY_mean,LESION_LESION_PRECISION_mean,LESION_LESION_F1_mean,LESION_N_GT_LESIONS_total,LESION_N_PRED_LESIONS_total,LESION_TP_LESIONS_total,LESION_FN_LESIONS_total,LESION_FP_LESIONS_total
2
+ 1,unet,0,70,0.924531397,0.843338613,0.883935005,0.929979265,0.882387206,0.906183236,0.919827829,0.812886653,0.866357241,0.860059739,0.730685865,0.795372802,0.999113214,0.999260555,0.999186885,1,5.669535449,3.334767724,,,,,,,,,,,,,,,,,0.810285987,0.717464393,0.753906356,275,309,226,49,84
3
+ 1,unet,1,70,0.931030738,0.850885379,0.890958059,0.940815576,0.886398241,0.913606909,0.921931425,0.821858445,0.871894935,0.87127183,0.742075089,0.806673459,0.999243989,0.999293762,0.999268875,1,3.105706418,2.052853209,,,,,,,,,,,,,,,,,0.831870007,0.757975006,0.788804562,275,308,230,45,76
4
+ 1,unet,2,70,0.932586962,0.85627492,0.894430941,0.939876251,0.893898408,0.916887329,0.925870883,0.824199123,0.875035003,0.873940237,0.750010726,0.811975482,0.999235735,0.999297628,0.999266682,1,4.274767062,2.637383531,,,,,,,,,,,,,,,,,0.8190735,0.761199505,0.785236019,275,299,227,48,69
5
+ 1,unet,3,70,0.930374301,0.838005564,0.884189933,0.940570217,0.881982903,0.91127656,0.921092709,0.803821901,0.862457305,0.870086726,0.722712824,0.796399775,0.999231167,0.999284157,0.999257662,1,5.267095904,3.133547952,,,,,,,,,,,,,,,,,0.803511136,0.755088192,0.768750434,275,312,221,54,92
6
+ 2,attnunet,0,70,0.927807826,0.84530857,0.886558198,0.933405023,0.891192522,0.912298773,0.92300757,0.80690078,0.864954175,0.86572883,0.733543031,0.799635931,0.9991422,0.99933414,0.99923817,1,6.817768379,3.90888419,,,,,,,,,,,,,,,,,0.805918858,0.737682024,0.763465574,275,311,224,51,81
7
+ 2,attnunet,1,70,0.921130442,0.827778912,0.874454677,0.935079326,0.882534684,0.908807005,0.908634865,0.784884088,0.846759477,0.854381693,0.707421007,0.78090135,0.999170654,0.999314091,0.999242373,1,4.299965705,2.649982853,,,,,,,,,,,,,,,,,0.797643669,0.741290568,0.757996731,275,306,221,54,83
8
+ 2,attnunet,2,70,0.925456985,0.832425877,0.878941431,0.93222968,0.903420396,0.917825038,0.919811356,0.776171402,0.847991379,0.861589705,0.714097406,0.787843556,0.999096918,0.999459102,0.99927801,1,5.229465099,3.114732549,,,,,,,,,,,,,,,,,0.800801564,0.781510442,0.783598099,275,291,222,53,66
9
+ 2,attnunet,3,70,0.867560953,0.80238965,0.834975301,0.880663336,0.877374368,0.879018852,0.856234009,0.751981835,0.804107922,0.767030381,0.671503295,0.719266838,0.998508343,0.999362192,0.998935268,1.913197151,7.12564331,4.51942023,,,,,,,,,,,,,,,,,0.800023323,0.639032382,0.699206596,275,339,222,53,112
10
+ 3,dlv3unet,0,70,0.900234027,0.780845085,0.840539556,0.896169701,0.794573359,0.84537153,0.904815955,0.772086706,0.838451331,0.819275677,0.642953844,0.731114761,0.9987224,0.998555632,0.998639016,1,5.423612659,3.21180633,,,,,,,,,,,,,,,,,0.753118133,0.708678143,0.719359836,275,287,209,66,85
11
+ 3,dlv3unet,1,70,0.903263755,0.776870874,0.840067314,0.90318543,0.812752662,0.857969046,0.90365004,0.748937606,0.826293823,0.824368179,0.63776127,0.731064725,0.998827037,0.998762871,0.998794954,1,4.251005732,2.625502866,,,,,,,,,,,,,,,,,0.730627383,0.773161114,0.746853388,275,257,200,75,65
12
+ 3,dlv3unet,2,70,0.897481647,0.77235212,0.834916884,0.898679638,0.790748363,0.844714001,0.896838192,0.760167666,0.828502929,0.814705074,0.631956832,0.723330953,0.998738575,0.998560615,0.998649595,1,5.08065867,3.040329335,,,,,,,,,,,,,,,,,0.710273969,0.713006921,0.702347611,275,273,196,79,86
13
+ 3,dlv3unet,3,70,0.901285368,0.775199414,0.838242391,0.901044527,0.775105014,0.838074771,0.902041939,0.780304591,0.841173265,0.820925881,0.635412356,0.728169118,0.998768492,0.998388255,0.998578374,1,4.095494924,2.547747462,,,,,,,,,,,,,,,,,0.686803058,0.707394828,0.686217762,275,264,189,86,84
14
+ 4,transunet,0,70,0.928372145,0.841728286,0.885050215,0.948178381,0.902428523,0.925303452,0.910090008,0.792625451,0.85135773,0.866710076,0.728115177,0.797412627,0.999358865,0.999433962,0.999396413,1,8.744998559,4.872499279,,,,,,,,,,,,,,,,,0.791318113,0.78846014,0.778279842,275,299,224,51,71
15
+ 4,transunet,1,70,0.928437459,0.837057166,0.882747312,0.935028265,0.907471981,0.921250123,0.9224873,0.779866141,0.851176721,0.866723835,0.721270863,0.793997349,0.999162149,0.99945489,0.999308519,1,6.680013727,3.840006863,,,,,,,,,,,,,,,,,0.769424443,0.762231589,0.753840342,275,297,214,61,74
16
+ 4,transunet,2,70,0.924795158,0.836490964,0.880643061,0.928976828,0.90522043,0.917098629,0.921395395,0.780581474,0.850988434,0.860372024,0.720189454,0.790280739,0.999106986,0.999466578,0.999286782,1,4.026591793,2.513295897,,,,,,,,,,,,,,,,,0.768415042,0.766517149,0.760036199,275,282,215,60,64
17
+ 4,transunet,3,70,0.917144393,0.813891221,0.865517807,0.915840511,0.898955881,0.907398196,0.919883748,0.751553436,0.835718592,0.847458046,0.687543773,0.767500909,0.998858155,0.999475554,0.999166854,1,4.265120809,2.632560404,,,,,,,,,,,,,,,,,0.810819711,0.695502554,0.740358673,275,330,225,50,98
18
+ ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
19
+ ,unet - mean,,,0.9296,0.8471,0.8884,0.9378,0.8862,0.912,0.9222,0.8157,0.8689,0.8688,0.7364,0.8026,0.9992,0.9993,0.9992,1,4.6,2.8,,,,,,,,,,,,,,,,,0.8162,0.7479,0.7742,275,307,226,49,80.25
20
+ ,attn - mean,,,0.9105,0.827,0.8687,0.9203,0.8886,0.9045,0.9019,0.78,0.841,0.8372,0.7066,0.7719,0.999,0.9994,0.9992,1.2,5.9,3.5,,,,,,,,,,,,,,,,,0.8011,0.7249,0.7511,275,311.75,222.25,52.75,85.5
21
+ ,dlv3 - mean,,,0.9006,0.7763,0.8384,0.8998,0.7933,0.8465,0.9018,0.7654,0.8336,0.8198,0.637,0.7284,0.9988,0.9986,0.9987,1,4.7,2.9,,,,,,,,,,,,,,,,,0.7202,0.7256,0.7137,275,270.25,198.5,76.5,80
22
+ ,trans - mean,,,0.9247,0.8323,0.8785,0.932,0.9035,0.9178,0.9185,0.7762,0.8473,0.8603,0.7143,0.7873,0.9991,0.9995,0.9993,1,5.9,3.5,,,,,,,,,,,,,,,,,0.785,0.7532,0.7581,275,302,219.5,55.5,76.75
23
+ ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
24
+ ,unet - std,,,0.0031,0.007,0.0045,0.0045,0.0048,0.0039,0.0023,0.008,0.0049,0.0053,0.0105,0.007,0.0001,0,0,0,1,0.5,,,,,,,,,,,,,,,,,0.0106,0.0177,0.0139,0,4.8477,3.2404,3.2404,8.6132
25
+ ,attn - std,,,0.0249,0.0156,0.02,0.0229,0.0099,0.0151,0.0269,0.0197,0.0225,0.0407,0.0224,0.0311,0.0003,0.0001,0.0001,0.4,1.2,0.7,,,,,,,,,,,,,,,,,0.003,0.0525,0.0314,0,17.3692,1.0897,1.0897,16.6508
26
+ ,dlv3 - std,,,0.0021,0.0031,0.0022,0.0026,0.0134,0.0072,0.003,0.0119,0.0063,0.0035,0.004,0.0032,0,0.0001,0.0001,0,0.6,0.3,,,,,,,,,,,,,,,,,0.0245,0.0276,0.0224,0,11.211,7.2284,7.2284,8.6891
27
+ ,trans - std,,,0.0046,0.0108,0.0076,0.0116,0.0032,0.0066,0.0049,0.0151,0.0067,0.0079,0.0157,0.0117,0.0002,0,0.0001,0,1.9,1,,,,,,,,,,,,,,,,,0.0175,0.0348,0.0136,0,17.4499,5.0249,5.0249,12.794
models/for_WMH_Vent/folds_results_zscore2_all/training_info_all_variants_folds.csv ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Variant,Variant_Name,Fold,Best_Epoch,Composite_Score,Total_Epochs,First_Valid_Epoch,Total_Valid_Epochs,Best_Epoch_Val_Loss,Best_Epoch_Dice_Ventricles,Best_Epoch_Dice_Abnormal_WMH,Best_Epoch_Dice_Mean,Best_Abnormal_Epoch,Best_Abnormal_Dice,Best_Ventricles_Epoch,Best_Ventricles_Dice
2
+ 1,unet,0,49,0.837773437480731,60,1,60,0.24741753935813904,0.9308493801953847,0.8054339622632157,0.9115566193830764,43,0.8058087693216404,49,0.9308493801953847
3
+ 1,unet,1,45,0.8509202240606865,60,1,60,0.3080134391784668,0.9268369262837436,0.8394508168223501,0.9215010850395001,28,0.8441074580031014,38,0.9274915960857308
4
+ 1,unet,2,36,0.8128944361644407,60,1,60,0.27736401557922363,0.9342240045327603,0.7672727272708917,0.9000714500704411,32,0.7696575927137208,34,0.9378331718769447
5
+ 1,unet,3,41,0.8148548201069025,60,1,60,0.3056482672691345,0.9412208603997376,0.7717556478564912,0.9037425888471717,41,0.7717556478564912,44,0.9415513142951589
6
+ 2,attnunet,0,38,0.8465806985395226,60,1,60,0.2354777455329895,0.9361820594989245,0.8154564254052402,0.9167136699540254,38,0.8154564254052402,49,0.9369088654755128
7
+ 2,attnunet,1,42,0.8468065449382642,60,1,60,0.3282952904701233,0.9189075870475,0.8399396631183776,0.9189869898404228,42,0.8399396631183776,42,0.9189075870475
8
+ 2,attnunet,2,35,0.8082210232243792,60,1,60,0.2833690643310547,0.9301114433264854,0.7625408277658984,0.8971071685730373,35,0.7625408277658984,38,0.932546742487403
9
+ 2,attnunet,3,35,0.7675559444491301,60,1,60,0.3719373941421509,0.8997412800024551,0.7247121664376812,0.8740189336882455,35,0.7247121664376812,51,0.9082138618936411
10
+ 3,dlv3unet,0,41,0.7945477803722963,60,1,60,0.3116353750228882,0.8988588122221663,0.7600894570132255,0.8856122734890453,41,0.7600894570132255,54,0.9004052827384709
11
+ 3,dlv3unet,1,42,0.8150221762616997,60,1,60,0.3728603720664978,0.9037163637265287,0.8019888405839849,0.9011564065443626,42,0.8019888405839849,40,0.9049275398385249
12
+ 3,dlv3unet,2,28,0.7727672322932403,60,1,60,0.34316256642341614,0.9029322657428571,0.7270063486878747,0.8760687237787342,34,0.7281193622294404,40,0.9059795923856953
13
+ 3,dlv3unet,3,28,0.768303148626621,60,1,60,0.3877088725566864,0.9124585568837787,0.7222274480285933,0.8774333183819368,28,0.7222274480285933,37,0.916682545438336
14
+ 4,transunet,0,39,0.8410510311241638,60,1,60,0.24608999490737915,0.9346885813142664,0.808755760367703,0.9139696605000669,35,0.8090334741168289,45,0.9357508099451346
15
+ 4,transunet,1,48,0.8483119522767122,60,1,60,0.3149019777774811,0.9253444084272909,0.8369980458771218,0.9201816141429934,29,0.8392693984451316,50,0.9254004460337096
16
+ 4,transunet,2,35,0.8109694312756469,60,1,60,0.27634185552597046,0.9331867846270409,0.7644126357335528,0.89876925662194,35,0.7644126357335528,55,0.9336813537844856
17
+ 4,transunet,3,28,0.7773769906034641,60,1,60,0.34197694063186646,0.9332184349847285,0.7193485902853871,0.883513943096132,28,0.7193485902853871,50,0.9413152166406702
models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_test.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Variant,Variant_Name,N_Folds,DICE_Mean,DICE_Std,DICE_Class1_Mean,DICE_Class1_Std,DICE_Class2_Mean,DICE_Class2_Std,PRECISION_Mean,PRECISION_Std,PRECISION_Class1_Mean,PRECISION_Class1_Std,PRECISION_Class2_Mean,PRECISION_Class2_Std,RECALL_Mean,RECALL_Std,RECALL_Class1_Mean,RECALL_Class1_Std,RECALL_Class2_Mean,RECALL_Class2_Std,IOU_Mean,IOU_Std,IOU_Class1_Mean,IOU_Class1_Std,IOU_Class2_Mean,IOU_Class2_Std,SPECIFICITY_Mean,SPECIFICITY_Std,SPECIFICITY_Class1_Mean,SPECIFICITY_Class1_Std,SPECIFICITY_Class2_Mean,SPECIFICITY_Class2_Std,HD95_Mean,HD95_Std,HD95_Class1_Mean,HD95_Class1_Std,HD95_Class2_Mean,HD95_Class2_Std,LESION_SENSITIVITY_Mean,LESION_SENSITIVITY_Std,LESION_PRECISION_Mean,LESION_PRECISION_Std,LESION_F1_Mean,LESION_F1_Std,LESION_N_GT_LESIONS_Total,LESION_N_PRED_LESIONS_Total,LESION_TP_LESIONS_Total,LESION_FN_LESIONS_Total,LESION_FP_LESIONS_Total
2
+ 1,unet,4,0.88837848442577,0.004488176438408171,0.9296308495604303,0.003051861083997252,0.8471261192911104,0.006988603634009174,0.9119885083644996,0.003899542620426997,0.937810327296536,0.004534371323946414,0.8861666894324636,0.004785959852918547,0.8689361210113892,0.004862525795740292,0.9221807114485115,0.002258280011483868,0.8156915305742668,0.008049307835491817,0.8026053795053394,0.006985123311053306,0.86883963293893,0.005257697310767231,0.7363711260717487,0.01045759631012397,0.9992450258555589,3.3829763475009536e-05,0.9992060262932462,5.3781685628696937e-05,0.9992840254178714,1.4410851399106235e-05,2.789638104058208,0.4953467782180657,1.0,0.0,4.579276208116416,0.9906935564361317,0.8161851575612329,0.010604103780010409,0.7479317737104494,0.01772268907341698,0.7741743428489778,0.01393389833458131,1100,1228,904,196,321
3
+ 2,attnunet,4,0.8687324016526609,0.019964152607882032,0.9104890513851166,0.024899999222747722,0.826975751920205,0.015579775036519973,0.9044874168421302,0.015051711319037623,0.9203443411150141,0.02293220806285698,0.8886304925692461,0.009863625617703263,0.8409532379708647,0.02245478860306288,0.9019219497921485,0.026912289770452267,0.779984526149581,0.01966711140506486,0.7719119185728298,0.031123754171912342,0.83718265247343,0.040705456147877725,0.7066411846722295,0.022442058844092443,0.9991734550617465,0.000138385919780625,0.9989795287466985,0.000273307129456502,0.9993673813767945,5.564253319795054e-05,3.5482549554417186,0.7190357923335028,1.2282992876459566,0.3954259655345788,5.868210623237481,1.1565233310098124,0.8010968535966527,0.0030172783603819066,0.7248788539521523,0.05246431861595214,0.7510667501711839,0.0314226022122775,1100,1247,889,211,342
4
+ 3,dlv3unet,4,0.8384415363144644,0.0022083747230999306,0.9005661992435416,0.0020867923289419102,0.7763168733853871,0.003073255677872925,0.8465323369072222,0.0071934443649559685,0.8997698242284116,0.002619190013214462,0.7932948495860329,0.01339398897949029,0.8336053369013228,0.0063294943822809775,0.9018365317555639,0.00304923631516394,0.7653741420470819,0.01188742959950416,0.7284198893837046,0.003170869073367662,0.8198187029362412,0.0034773428576675017,0.6370210758311682,0.003999380279290716,0.9986654847223488,7.953573517565171e-05,0.9987641261362938,3.990839508953154e-05,0.9985668433084038,0.00013274507182798207,2.8563464981341697,0.27781447224790423,1.0,0.0,4.7126929962683395,0.5556289444958085,0.720205635740483,0.024526594995177977,0.7255602516176558,0.02756091610984561,0.7136946491830041,0.022446218734487867,1100,1081,794,306,320
5
+ 4,transunet,4,0.8784895989143288,0.007649748417636039,0.9246872887842248,0.004597522753464204,0.8322919090444327,0.010816310171427137,0.9177625999227637,0.00664998101397249,0.9320059959760637,0.011631626135529186,0.9035192038694635,0.003183241810633453,0.8473103690612981,0.006693789315880423,0.9184641125298365,0.004922383784531681,0.7761566255927599,0.015082099936625985,0.7872979059128824,0.011704790584520298,0.8603159951545595,0.007862485177144832,0.7142798166712054,0.01573238533985691,0.9992896422936255,8.191687124816227e-05,0.9991215386213639,0.00017857472387549319,0.9994577459658872,1.5563988190790074e-05,3.4645906109504088,0.964414340383405,1.0,0.0,5.929181221900818,1.9288286807668098,0.784994327470758,0.01749453278425242,0.7531778580152171,0.03475336972283068,0.7581287639607504,0.013636999062794438,1100,1208,878,222,307
models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_training.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Variant,Variant_Name,N_Folds,Best_Epoch_Mean,Best_Epoch_Std,Best_Epoch_Min,Best_Epoch_Max,Composite_Score_Mean,Composite_Score_Std,Best_Epoch_Val_Loss_Mean,Best_Epoch_Val_Loss_Std,Best_Epoch_Dice_Mean_Mean,Best_Epoch_Dice_Mean_Std,Best_Epoch_Dice_Ventricles_Mean,Best_Epoch_Dice_Ventricles_Std,Best_Epoch_Dice_Abnormal_WMH_Mean,Best_Epoch_Dice_Abnormal_WMH_Std
2
+ 1,unet,4,42.75,4.815340071064556,36,49,0.8291107294531901,0.0159444009351335,0.284610815346241,0.02462779461924618,0.9092179358350473,0.00821557699733339,0.9332827928529066,0.005276587173391204,0.7959782885532373,0.02911192217235299
3
+ 2,attnunet,4,37.5,2.8722813232690143,35,42,0.8172910527878241,0.03272955110785234,0.3047698736190796,0.05080431749397413,0.9017066905139327,0.018107910266314746,0.9212355924688412,0.013870856883885596,0.7856622706817994,0.044953429555710196
4
+ 3,dlv3unet,4,34.75,6.7592529172978875,28,42,0.7876600843884642,0.018658861388516038,0.35384179651737213,0.029172388598382663,0.8850676805485197,0.009980085098274465,0.9044914996438328,0.0049556335292024285,0.7528280235784197,0.03190873102426834
5
+ 4,transunet,4,37.5,7.22841614740048,28,48,0.8194273513199967,0.028025314220223092,0.2948276922106743,0.0365483168698032,0.904108618590283,0.01421469994866525,0.9316095523383316,0.0036677177073460732,0.7823787580659411,0.0446502980212195
models/for_WMH_Vent/model_training_scripts/attn_unet_model.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###################### Libraries ######################
2
+ # Deep Learning
3
+ import keras
4
+ from keras.models import Model
5
+ from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
6
+
7
+
8
def build_attention_unet_3class(input_shape=(256, 256, 1), num_classes=3):
    """Build a 2D Attention U-Net with dropout regularisation.

    The network has four encoder stages (64/128/256/512 filters), a
    1024-filter bridge, and four decoder stages. Each decoder stage upsamples
    with a transposed convolution, re-weights the matching encoder feature map
    through an attention gate, and concatenates both before two 3x3
    convolutions.

    Args:
        input_shape: Shape of one input sample, passed unchanged to ``Input``.
            NOTE(review): the four 2x poolings presumably require height and
            width divisible by 16 for the skip concatenations to line up;
            confirm before using other sizes.
        num_classes: Number of output channels. ``1`` yields a single sigmoid
            channel; any other value yields a softmax over ``num_classes``
            channels.

    Returns:
        An uncompiled ``keras.models.Model`` mapping ``inputs`` to the
        per-pixel prediction map.
    """

    def double_conv(tensor, filters):
        """Apply two consecutive 3x3 ReLU convolutions with 'same' padding."""
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    def attention_block(gating, skip, inter_channels):
        """Scale ``skip`` by a one-channel sigmoid mask computed from both inputs."""
        gating_proj = Conv2D(inter_channels, 1, padding='same')(gating)
        skip_proj = Conv2D(inter_channels, 1, padding='same')(skip)
        combined = keras.layers.Add()([gating_proj, skip_proj])
        combined = keras.layers.Activation('relu')(combined)
        mask = Conv2D(1, 1, padding='same')(combined)
        mask = keras.layers.Activation('sigmoid')(mask)
        return keras.layers.Multiply()([skip, mask])

    # (filters, dropout rate applied after pooling) for each encoder stage.
    encoder_plan = ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2))
    # (filters, attention intermediate channels, dropout rate after concat)
    # for each decoder stage, deepest first.
    decoder_plan = ((512, 256, 0.2), (256, 128, 0.2), (128, 64, 0.1), (64, 32, 0.1))

    inputs = Input(input_shape)

    # Encoder: keep the pre-pooling feature map of every stage as a skip.
    skip_maps = []
    features = inputs
    for filters, drop_rate in encoder_plan:
        features = double_conv(features, filters)
        skip_maps.append(features)
        features = MaxPooling2D(2)(features)
        features = keras.layers.Dropout(drop_rate)(features)

    # Bridge
    features = double_conv(features, 1024)
    features = keras.layers.Dropout(0.3)(features)

    # Decoder: walk the skips from deepest to shallowest.
    for (filters, gate_channels, drop_rate), skip in zip(decoder_plan, reversed(skip_maps)):
        upsampled = Conv2DTranspose(filters, 2, strides=2, padding='same')(features)
        gated_skip = attention_block(upsampled, skip, gate_channels)
        features = concatenate([upsampled, gated_skip])
        features = keras.layers.Dropout(drop_rate)(features)
        features = double_conv(features, filters)

    # Output head: sigmoid for the single-channel case, softmax otherwise.
    if num_classes == 1:
        head = Conv2D(1, 1, activation='sigmoid')
    else:
        head = Conv2D(num_classes, 1, activation='softmax')
    outputs = head(features)

    return Model(inputs, outputs)
85
+
models/for_WMH_Vent/model_training_scripts/base_runner_all.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ import numpy as np
4
+
5
+
6
# Run every (fold, variant) experiment one after another, each in its own
# Python subprocess so that one run cannot leak state into the next.

EXPERIMENT_SCRIPT = "p4_run_experiments_all.py"
SCENARIO = "standard_3class"
NUM_FOLDS = 4
NUM_VARIANTS = 5

# Folds / variants listed here are skipped (for example because they have
# already been trained). Variant 0 is skipped, as in the original script.
SKIP_FOLDS = frozenset()
SKIP_VARIANTS = frozenset({0})


def iter_experiments(num_folds=NUM_FOLDS, num_variants=NUM_VARIANTS,
                     skip_folds=SKIP_FOLDS, skip_variants=SKIP_VARIANTS):
    """Yield ``(fold, variant)`` pairs in fold-major order, omitting skipped ones."""
    for fold in range(num_folds):
        if fold in skip_folds:
            continue
        for variant in range(num_variants):
            if variant in skip_variants:
                continue
            yield fold, variant


def build_command(fold, variant, scenario=SCENARIO):
    """Return the argv list that launches one experiment with this interpreter."""
    return [
        sys.executable, EXPERIMENT_SCRIPT,
        "--variant", str(variant),
        "--fold", str(fold),
        "--scenario", scenario,
    ]


def main():
    """
    Launch all configured experiments sequentially.

    A failing run does not stop the remaining ones; failures are reported at
    the end.

    Returns:
        0 when every run exited with status 0, otherwise 1.
    """
    failures = []

    for fold, variant in iter_experiments():
        print(f"Running fold={fold} variant={variant} scenario={SCENARIO}")
        # List-form argv (no shell) so arguments are passed through verbatim.
        completed = subprocess.run(build_command(fold, variant))
        if completed.returncode != 0:
            failures.append((fold, variant, completed.returncode))

    for fold, variant, returncode in failures:
        print(f"FAILED: fold={fold} variant={variant} (exit code {returncode})")

    return 1 if failures else 0


if __name__ == "__main__":
    sys.exit(main())
22
+
23
+
models/for_WMH_Vent/model_training_scripts/dlv3_unet_model.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###################### Libraries ######################
2
+ # Deep Learning
3
+ import tensorflow as tf
4
+ import keras
5
+ from keras.models import Model, load_model
6
+ from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
7
+ from keras import backend as K
8
+ from tensorflow.keras import layers, optimizers, callbacks
9
+ from keras.utils import to_categorical
10
+
11
+
12
def build_deeplabv3_unet_3class(input_shape=(256, 256, 1), num_classes=3,
                                aspp_upsample_factor=4):
    """
    Build a DeepLabV3+-style encoder/decoder on a ResNet-50-shaped backbone.

    The encoder follows the ResNet-50 block layout (3-4-6-3 bottleneck
    blocks). Every BatchNormalization layer is disabled in this variant, so
    each convolution is followed directly by ReLU (or by the residual add).

    Resolution of the feature maps relative to the input:
        conv1 + pool1, conv2_x ... 1/4  (low-level features for the decoder)
        conv3_x .................. 1/8
        conv4_x, conv5_x, ASPP ... 1/8  (atrous convolutions, no striding)

    Reference:
        "Encoder-Decoder with Atrous Separable Convolution for Semantic
        Image Segmentation", Chen et al. 2018.

    Args:
        input_shape: Shape of one input image, (H, W, C).
        num_classes: Number of output classes. 1 produces a single sigmoid
            channel, any other value a softmax over ``num_classes`` channels.
        aspp_upsample_factor: Bilinear upsampling factor applied to the ASPP
            output before it is fused with the low-level features. The ASPP
            output is at 1/8 resolution and the low-level features at 1/4,
            so a factor of 2 aligns the two maps. The default of 4 is kept
            for backward compatibility only: it upsamples to 1/2 resolution
            and the 'match_dims' crop then keeps just the top-left region,
            so the fused high-level features are NOT spatially aligned with
            the low-level ones. Weights trained with the default expect that
            behaviour (presumably including the released checkpoints -
            confirm before changing); pass 2 when retraining from scratch.

    Returns:
        keras.Model named 'DeepLabV3Plus_ResNet50'.

    Raises:
        ValueError: If ``aspp_upsample_factor`` is smaller than 2, which
            would leave the ASPP output smaller than the low-level features.
    """
    if aspp_upsample_factor < 2:
        raise ValueError(
            "aspp_upsample_factor must be >= 2 (ASPP output is at 1/8 "
            f"resolution, low-level features at 1/4); got {aspp_upsample_factor}"
        )

    def bottleneck_residual_block(x, filters, strides=1, dilation_rate=1,
                                  projection_shortcut=False, name_prefix=""):
        """ResNet-50 bottleneck block (1x1 -> 3x3 -> 1x1) with optional atrous 3x3."""
        shortcut = x

        # Project the shortcut when the channel count or stride changes.
        if projection_shortcut:
            shortcut = layers.Conv2D(filters * 4, 1, strides=strides, use_bias=False,
                                     name=f"{name_prefix}_0_conv")(shortcut)

        x = layers.Conv2D(filters, 1, use_bias=False,
                          name=f"{name_prefix}_1_conv")(x)
        x = layers.Activation('relu')(x)

        x = layers.Conv2D(filters, 3, strides=strides, padding='same',
                          dilation_rate=dilation_rate, use_bias=False,
                          name=f"{name_prefix}_2_conv")(x)
        x = layers.Activation('relu')(x)

        x = layers.Conv2D(filters * 4, 1, use_bias=False,
                          name=f"{name_prefix}_3_conv")(x)

        x = layers.Add()([shortcut, x])
        x = layers.Activation('relu')(x)
        return x

    def aspp_block(x, filters=256):
        """Atrous Spatial Pyramid Pooling: 1x1, three atrous 3x3 and an image-level branch."""
        # Branch 1: 1x1 convolution
        b1 = layers.Conv2D(filters, 1, use_bias=False, name='aspp_1x1')(x)
        b1 = layers.Activation('relu')(b1)

        # Branch 2: 3x3 convolution, rate = 6
        b2 = layers.Conv2D(filters, 3, padding='same', dilation_rate=6,
                           use_bias=False, name='aspp_3x3_6')(x)
        b2 = layers.Activation('relu')(b2)

        # Branch 3: 3x3 convolution, rate = 12
        b3 = layers.Conv2D(filters, 3, padding='same', dilation_rate=12,
                           use_bias=False, name='aspp_3x3_12')(x)
        b3 = layers.Activation('relu')(b3)

        # Branch 4: 3x3 convolution, rate = 18
        b4 = layers.Conv2D(filters, 3, padding='same', dilation_rate=18,
                           use_bias=False, name='aspp_3x3_18')(x)
        b4 = layers.Activation('relu')(b4)

        # Branch 5: image-level features (global average pooling)
        b5 = layers.GlobalAveragePooling2D(name='aspp_gap')(x)
        b5 = layers.Reshape((1, 1, -1))(b5)
        b5 = layers.Conv2D(filters, 1, use_bias=False, name='aspp_gap_conv')(b5)
        b5 = layers.Activation('relu')(b5)

        # Broadcast the 1x1 map back to the spatial size of `x`. The target
        # size is read inside the Lambda so that no raw TensorFlow op is
        # applied to a symbolic Keras tensor outside of a layer.
        def resize_like(tensors):
            features, reference = tensors
            return tf.image.resize(features, tf.shape(reference)[1:3],
                                   method='bilinear')

        b5 = layers.Lambda(resize_like, name='aspp_gap_resize')([b5, x])

        # Fuse all branches
        concat_features = layers.Concatenate(name='aspp_concat')([b1, b2, b3, b4, b5])
        output = layers.Conv2D(filters, 1, use_bias=False,
                               name='aspp_final_conv')(concat_features)
        output = layers.Activation('relu')(output)
        output = layers.Dropout(0.1, name='aspp_dropout')(output)
        return output

    # Input layer
    inputs = layers.Input(input_shape, name='input')

    # ==================== ENCODER (ResNet-50 layout) ====================

    # Stem: 7x7 stride-2 convolution + stride-2 max pooling -> 1/4 resolution
    x = layers.Conv2D(64, 7, strides=2, padding='same', use_bias=False,
                      name='conv1')(inputs)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(3, strides=2, padding='same', name='pool1')(x)

    # Stage 1 (conv2_x), 1/4 resolution - low-level features for the decoder
    x = bottleneck_residual_block(x, 64, strides=1, projection_shortcut=True,
                                  name_prefix='conv2_block1')
    x = bottleneck_residual_block(x, 64, name_prefix='conv2_block2')
    low_level_features = bottleneck_residual_block(x, 64, name_prefix='conv2_block3')

    # Stage 2 (conv3_x), 1/8 resolution
    x = bottleneck_residual_block(low_level_features, 128, strides=2,
                                  projection_shortcut=True,
                                  name_prefix='conv3_block1')
    for block_id in range(2, 5):
        x = bottleneck_residual_block(x, 128, name_prefix=f'conv3_block{block_id}')

    # Stage 3 (conv4_x), atrous rate 2, resolution stays at 1/8
    x = bottleneck_residual_block(x, 256, strides=1, dilation_rate=2,
                                  projection_shortcut=True,
                                  name_prefix='conv4_block1')
    for block_id in range(2, 7):
        x = bottleneck_residual_block(x, 256, dilation_rate=2,
                                      name_prefix=f'conv4_block{block_id}')

    # Stage 4 (conv5_x), atrous rate 4, resolution stays at 1/8
    x = bottleneck_residual_block(x, 512, strides=1, dilation_rate=4,
                                  projection_shortcut=True,
                                  name_prefix='conv5_block1')
    for block_id in range(2, 4):
        x = bottleneck_residual_block(x, 512, dilation_rate=4,
                                      name_prefix=f'conv5_block{block_id}')

    # ==================== ASPP MODULE ====================
    x = aspp_block(x, filters=256)

    # ==================== DECODER ====================

    # Upsample the ASPP output (1/8 resolution) towards the low-level
    # features (1/4 resolution). See `aspp_upsample_factor` in the docstring:
    # 2 aligns the maps, the legacy default of 4 overshoots and relies on the
    # crop below.
    x = layers.UpSampling2D(size=(aspp_upsample_factor, aspp_upsample_factor),
                            interpolation='bilinear',
                            name='decoder_upsample1')(x)

    # Reduce low-level feature channels to 48
    low_level_features = layers.Conv2D(48, 1, use_bias=False,
                                       name='decoder_low_level_conv')(low_level_features)
    low_level_features = layers.Activation('relu')(low_level_features)

    # Crop the high-level map to the low-level spatial size. With a factor of
    # 2 this only trims an off-by-one row/column for odd sizes; with the
    # legacy factor of 4 it discards everything but the top-left region.
    def match_spatial_dims(tensors):
        high_level, low_level = tensors
        low_shape = tf.shape(low_level)
        high_level_matched = high_level[:, :low_shape[1], :low_shape[2], :]
        return high_level_matched, low_level

    x_matched, low_level_matched = layers.Lambda(
        match_spatial_dims, name='match_dims'
    )([x, low_level_features])

    # Concatenate high-level and low-level features
    x = layers.Concatenate(name='decoder_concat')([x_matched, low_level_matched])

    # Refine the fused features (light dropout for regularization)
    x = layers.Conv2D(256, 3, padding='same', use_bias=False, name='decoder_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(0.1, name='decoder_dropout1')(x)

    x = layers.Conv2D(256, 3, padding='same', use_bias=False, name='decoder_conv2')(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(0.1, name='decoder_dropout2')(x)

    # Final 4x upsampling from 1/4 resolution back to the input resolution
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear',
                            name='decoder_upsample2')(x)

    # ==================== OUTPUT ====================
    if num_classes == 1:
        # Binary segmentation: single sigmoid channel
        outputs = layers.Conv2D(1, 1, activation='sigmoid', name='output')(x)
    else:
        # Multi-class segmentation: softmax over num_classes channels
        outputs = layers.Conv2D(num_classes, 1, activation='softmax', name='output')(x)

    model = keras.Model(inputs, outputs, name='DeepLabV3Plus_ResNet50')

    return model
models/for_WMH_Vent/model_training_scripts/dlv3_unet_model_GN.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###################### Libraries ######################
2
+ # Deep Learning
3
+ import tensorflow as tf
4
+ import keras
5
+ from keras.models import Model, load_model
6
+ from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
7
+ from keras import backend as K
8
+ from tensorflow.keras import layers, optimizers, callbacks
9
+ from keras.utils import to_categorical
10
+
11
+
12
def build_deeplabv3_unet_3class(input_shape=(256, 256, 1), num_classes=3,
                                aspp_upsample_factor=4):
    """
    DeepLabV3+-style network on a ResNet-50-shaped backbone with GroupNorm.

    Every normalization layer is a GroupNormalization with 4 groups.
    GroupNormalization computes its statistics per sample rather than per
    batch, so the normalization behaves the same in training and inference
    mode and does not depend on the batch size.

    Resolution of the feature maps relative to the input:
        conv1 + pool1, conv2_x ... 1/4  (low-level features for the decoder)
        conv3_x .................. 1/8
        conv4_x, conv5_x, ASPP ... 1/8  (atrous convolutions, no striding)

    Input:  single-channel (grayscale) MRI images -> (H, W, 1)
    Output: per-pixel class probabilities         -> (H, W, num_classes)
            or binary mask                        -> (H, W, 1) when num_classes==1

    Reference:
        "Encoder-Decoder with Atrous Separable Convolution for
        Semantic Image Segmentation", Chen et al. 2018.

    Args:
        input_shape: Shape of one input image, (H, W, C).
        num_classes: Number of output classes. 1 produces a single sigmoid
            channel, any other value a softmax over ``num_classes`` channels.
        aspp_upsample_factor: Bilinear upsampling factor applied to the ASPP
            output before it is fused with the low-level features. The ASPP
            output is at 1/8 resolution and the low-level features at 1/4,
            so a factor of 2 aligns the two maps. The default of 4 is kept
            for backward compatibility only: it upsamples to 1/2 resolution
            and the 'match_dims' crop then keeps just the top-left region,
            so the fused high-level features are NOT spatially aligned with
            the low-level ones. Weights trained with the default expect that
            behaviour (presumably including the released checkpoints -
            confirm before changing); pass 2 when retraining from scratch.

    Returns:
        keras.Model named 'DeepLabV3Plus_ResNet50_GN'.

    Raises:
        ValueError: If ``aspp_upsample_factor`` is smaller than 2, which
            would leave the ASPP output smaller than the low-level features.
    """
    if aspp_upsample_factor < 2:
        raise ValueError(
            "aspp_upsample_factor must be >= 2 (ASPP output is at 1/8 "
            f"resolution, low-level features at 1/4); got {aspp_upsample_factor}"
        )

    # ------------------------------------------------------------------
    # Helper: GroupNorm used wherever the original design had BatchNorm.
    # groups=4 divides every channel count used below (48, 64, ..., 2048).
    # ------------------------------------------------------------------
    def group_norm(name=None):
        return layers.GroupNormalization(groups=4, name=name)

    # ------------------------------------------------------------------
    def bottleneck_residual_block(x, filters, strides=1, dilation_rate=1,
                                  projection_shortcut=False, name_prefix=""):
        """ResNet-50 bottleneck block (1x1 -> 3x3 -> 1x1) with optional atrous 3x3."""
        shortcut = x

        # Project the shortcut when the channel count or stride changes.
        if projection_shortcut:
            shortcut = layers.Conv2D(
                filters * 4, 1, strides=strides, use_bias=False,
                name=f"{name_prefix}_0_conv"
            )(shortcut)
            shortcut = group_norm(name=f"{name_prefix}_0_gn")(shortcut)

        x = layers.Conv2D(filters, 1, use_bias=False,
                          name=f"{name_prefix}_1_conv")(x)
        x = group_norm(name=f"{name_prefix}_1_gn")(x)
        x = layers.Activation('relu')(x)

        x = layers.Conv2D(
            filters, 3, strides=strides, padding='same',
            dilation_rate=dilation_rate, use_bias=False,
            name=f"{name_prefix}_2_conv"
        )(x)
        x = group_norm(name=f"{name_prefix}_2_gn")(x)
        x = layers.Activation('relu')(x)

        x = layers.Conv2D(filters * 4, 1, use_bias=False,
                          name=f"{name_prefix}_3_conv")(x)
        x = group_norm(name=f"{name_prefix}_3_gn")(x)

        x = layers.Add()([shortcut, x])
        x = layers.Activation('relu')(x)
        return x

    # ------------------------------------------------------------------
    def aspp_block(x, filters=256):
        """Atrous Spatial Pyramid Pooling: 1x1, three atrous 3x3 and an image-level branch."""

        # Branch 1: 1x1 conv
        b1 = layers.Conv2D(filters, 1, use_bias=False, name='aspp_1x1')(x)
        b1 = group_norm(name='aspp_1x1_gn')(b1)
        b1 = layers.Activation('relu')(b1)

        # Branch 2: 3x3, rate=6
        b2 = layers.Conv2D(filters, 3, padding='same', dilation_rate=6,
                           use_bias=False, name='aspp_3x3_6')(x)
        b2 = group_norm(name='aspp_3x3_6_gn')(b2)
        b2 = layers.Activation('relu')(b2)

        # Branch 3: 3x3, rate=12
        b3 = layers.Conv2D(filters, 3, padding='same', dilation_rate=12,
                           use_bias=False, name='aspp_3x3_12')(x)
        b3 = group_norm(name='aspp_3x3_12_gn')(b3)
        b3 = layers.Activation('relu')(b3)

        # Branch 4: 3x3, rate=18
        b4 = layers.Conv2D(filters, 3, padding='same', dilation_rate=18,
                           use_bias=False, name='aspp_3x3_18')(x)
        b4 = group_norm(name='aspp_3x3_18_gn')(b4)
        b4 = layers.Activation('relu')(b4)

        # Branch 5: image-level global context via GAP + resize
        b5 = layers.GlobalAveragePooling2D(name='aspp_gap')(x)
        b5 = layers.Reshape((1, 1, -1))(b5)
        b5 = layers.Conv2D(filters, 1, use_bias=False,
                           name='aspp_gap_conv')(b5)
        b5 = group_norm(name='aspp_gap_gn')(b5)
        b5 = layers.Activation('relu')(b5)

        # Broadcast the 1x1 map back to the spatial size of `x`. The target
        # size is read inside the Lambda so that no raw TensorFlow op is
        # applied to a symbolic Keras tensor outside of a layer.
        def resize_like(tensors):
            features, reference = tensors
            return tf.image.resize(features, tf.shape(reference)[1:3],
                                   method='bilinear')

        b5 = layers.Lambda(resize_like, name='aspp_gap_resize')([b5, x])

        # Fuse all branches
        concat = layers.Concatenate(name='aspp_concat')([b1, b2, b3, b4, b5])
        out = layers.Conv2D(filters, 1, use_bias=False,
                            name='aspp_final_conv')(concat)
        out = group_norm(name='aspp_final_gn')(out)
        out = layers.Activation('relu')(out)
        out = layers.Dropout(0.1, name='aspp_dropout')(out)
        return out

    # ==================================================================
    # INPUT
    # ==================================================================
    inputs = layers.Input(input_shape, name='input')

    # ==================================================================
    # ENCODER - ResNet-50 layout
    # ==================================================================

    # Stem: 7x7 stride-2 convolution + stride-2 max pooling -> 1/4 resolution
    x = layers.Conv2D(64, 7, strides=2, padding='same',
                      use_bias=False, name='conv1')(inputs)
    x = group_norm(name='conv1_gn')(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(3, strides=2, padding='same', name='pool1')(x)

    # Stage 1 - conv2_x (output stride 4, low-level features for the decoder)
    x = bottleneck_residual_block(x, 64, strides=1,
                                  projection_shortcut=True,
                                  name_prefix='conv2_block1')
    x = bottleneck_residual_block(x, 64, name_prefix='conv2_block2')
    low_level_features = bottleneck_residual_block(x, 64,
                                                   name_prefix='conv2_block3')

    # Stage 2 - conv3_x (output stride 8)
    x = bottleneck_residual_block(low_level_features, 128, strides=2,
                                  projection_shortcut=True,
                                  name_prefix='conv3_block1')
    for i in range(2, 5):
        x = bottleneck_residual_block(x, 128, name_prefix=f'conv3_block{i}')

    # Stage 3 - conv4_x (atrous rate=2, keeps output stride at 8)
    x = bottleneck_residual_block(x, 256, strides=1, dilation_rate=2,
                                  projection_shortcut=True,
                                  name_prefix='conv4_block1')
    for i in range(2, 7):
        x = bottleneck_residual_block(x, 256, dilation_rate=2,
                                      name_prefix=f'conv4_block{i}')

    # Stage 4 - conv5_x (atrous rate=4, keeps output stride at 8)
    x = bottleneck_residual_block(x, 512, strides=1, dilation_rate=4,
                                  projection_shortcut=True,
                                  name_prefix='conv5_block1')
    for i in range(2, 4):
        x = bottleneck_residual_block(x, 512, dilation_rate=4,
                                      name_prefix=f'conv5_block{i}')

    # ==================================================================
    # ASPP MODULE
    # ==================================================================
    x = aspp_block(x, filters=256)

    # ==================================================================
    # DECODER
    # ==================================================================

    # Upsample the ASPP output (output stride 8) towards the low-level
    # features (output stride 4). See `aspp_upsample_factor` in the
    # docstring: 2 aligns the maps, the legacy default of 4 overshoots and
    # relies on the crop below.
    x = layers.UpSampling2D(size=(aspp_upsample_factor, aspp_upsample_factor),
                            interpolation='bilinear',
                            name='decoder_upsample1')(x)

    # Reduce low-level feature channels to 48 (as in the original paper)
    low_level_features = layers.Conv2D(
        48, 1, use_bias=False, name='decoder_low_level_conv'
    )(low_level_features)
    low_level_features = group_norm(name='decoder_low_level_gn')(low_level_features)
    low_level_features = layers.Activation('relu')(low_level_features)

    # Crop the high-level map to the low-level spatial size. With a factor of
    # 2 this only trims an off-by-one row/column for odd sizes; with the
    # legacy factor of 4 it discards everything but the top-left region.
    def match_spatial_dims(tensors):
        high_level, low_level = tensors
        low_shape = tf.shape(low_level)
        return high_level[:, :low_shape[1], :low_shape[2], :], low_level

    x_matched, low_matched = layers.Lambda(
        match_spatial_dims, name='match_dims'
    )([x, low_level_features])

    # Fuse high-level and low-level features
    x = layers.Concatenate(name='decoder_concat')([x_matched, low_matched])

    x = layers.Conv2D(256, 3, padding='same', use_bias=False,
                      name='decoder_conv1')(x)
    x = group_norm(name='decoder_conv1_gn')(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(0.1, name='decoder_dropout1')(x)

    x = layers.Conv2D(256, 3, padding='same', use_bias=False,
                      name='decoder_conv2')(x)
    x = group_norm(name='decoder_conv2_gn')(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(0.1, name='decoder_dropout2')(x)

    # Final 4x upsample from output stride 4 back to the input resolution
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear',
                            name='decoder_upsample2')(x)

    # ==================================================================
    # OUTPUT
    # ==================================================================
    if num_classes == 1:
        # Binary segmentation: sigmoid, single-channel mask
        outputs = layers.Conv2D(1, 1, activation='sigmoid', name='output')(x)
    else:
        # Multi-class segmentation: softmax over num_classes channels
        outputs = layers.Conv2D(num_classes, 1, activation='softmax',
                                name='output')(x)

    model = keras.Model(inputs, outputs, name='DeepLabV3Plus_ResNet50_GN')
    return model
models/for_WMH_Vent/model_training_scripts/p4_compute_class_weights.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 - Utility script to calculate inverse frequency weights for class balancing
3
+
4
+ Usage:
5
+ python p4_compute_class_weights.py --fold 0 --scenario 4class --preprocessing standard
6
+
7
+ Output:
8
+ Saves class weights to JSON file for reproducibility
9
+ Prints weights for use in training
10
+
11
+ Authors:
12
+ "Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
13
+
14
+ Developer:
15
+ "Mahdi Bashiri Bawil"
16
+ """
17
+
18
+ import numpy as np
19
+ import json
20
+ from pathlib import Path
21
+ from tqdm import tqdm
22
+ import argparse
23
+
24
+ # Import data loader
25
+ from p4_data_loader import DataConfig, P2DataLoader
26
+
27
+
28
def compute_class_frequencies(dataset, num_classes, total_samples=None):
    """
    Count how many mask pixels belong to each class across a dataset.

    Args:
        dataset: Iterable whose items are sequences with the target mask in
            second position, e.g. ``(paired_input, target_mask, ...)``; any
            further elements are ignored. Each mask may be a tensor exposing
            ``.numpy()`` or anything ``np.asarray`` accepts, and holds
            integer class ids (typically one batch of masks per item).
        num_classes: Number of classes (3 or 4); ids ``0 .. num_classes-1``
            are counted.
        total_samples: Number of items the iterable yields. When truthy, a
            tqdm progress bar with this total is shown.

    Returns:
        class_pixel_counts: int64 array with the pixel count of each class
        total_pixels: Total number of mask elements seen. Elements whose
            value is not a counted class id are still included here.
    """
    class_pixel_counts = np.zeros(num_classes, dtype=np.int64)
    total_pixels = 0

    print(f"Computing class frequencies for {num_classes}-class scenario...")

    iterator = tqdm(dataset, total=total_samples, desc="Processing") if total_samples else dataset

    for _, target_mask, *_ in iterator:
        masks = target_mask.numpy() if hasattr(target_mask, "numpy") else np.asarray(target_mask)

        # Count over the whole batch at once instead of mask by mask.
        for class_id in range(num_classes):
            class_pixel_counts[class_id] += np.count_nonzero(masks == class_id)

        total_pixels += masks.size

    return class_pixel_counts, total_pixels
60
+
61
+
62
def compute_inverse_frequency_weights(class_pixel_counts, num_classes):
    """
    Compute inverse frequency weights with normalization

    Args:
        class_pixel_counts: Pixel counts per class (array or any sequence
            of numbers)
        num_classes: Number of classes; the returned weights sum to this value

    Returns:
        class_weights: Normalized inverse frequency weights
        class_frequencies: Class frequencies (for reference)

    Raises:
        ValueError: If the counts sum to zero, i.e. no pixels were counted,
            because frequencies are undefined in that case.
    """
    counts = np.asarray(class_pixel_counts, dtype=np.float64)
    total_pixels = counts.sum()

    if total_pixels <= 0:
        raise ValueError(
            "Cannot compute class weights: class_pixel_counts sums to zero "
            "(was the dataset empty?)"
        )

    # Class frequencies
    class_frequencies = counts / total_pixels

    # Inverse frequency; epsilon keeps the weight of an absent class finite
    epsilon = 1e-6
    inverse_freq = 1.0 / (class_frequencies + epsilon)

    # Normalize weights to sum = num_classes
    # This keeps weights in a reasonable range while maintaining relative importance
    class_weights = inverse_freq / inverse_freq.sum() * num_classes

    return class_weights, class_frequencies
88
+
89
+
90
def compute_and_save_class_weights(fold_id, class_scenario, preprocessing,
                                   output_dir='class_weights'):
    """
    Compute class weights for a specific fold and scenario

    Builds the training split of the fold, counts the pixels of every class,
    derives normalized inverse-frequency weights, prints a summary and writes
    everything to
    ``<output_dir>/class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json``.

    Args:
        fold_id: Fold number
        class_scenario: '3class' selects 3 classes; any other value is
            treated as the 4-class scenario
        preprocessing: 'standard' or 'zoomed'
        output_dir: Directory to save weights (created if missing)

    Returns:
        Dictionary with weights and statistics (the content of the JSON file)
    """
    rule = "=" * 70
    batch_size = 8  # Larger batch for faster processing; counts are unaffected

    print("\n" + rule)
    print("COMPUTING CLASS WEIGHTS")
    print(rule)
    print(f"Fold: {fold_id}")
    print(f"Scenario: {class_scenario}")
    print(f"Preprocessing: {preprocessing}")
    print(rule + "\n")

    # Initialize data loader
    data_loader = P2DataLoader(DataConfig())

    # Determine number of classes
    num_classes = 3 if class_scenario == '3class' else 4

    def make_train_dataset():
        """Create a fresh, unshuffled training dataset for this fold."""
        return data_loader.create_dataset_for_fold(
            fold_id=fold_id,
            split='train',
            preprocessing=preprocessing,
            class_scenario=class_scenario,
            batch_size=batch_size,
            shuffle=False  # No need to shuffle for counting
        )

    # First pass: count the batches so the progress bar has a total.
    print("Loading training dataset...")
    num_batches = sum(1 for _ in make_train_dataset())
    print(f"Training batches (batch_size={batch_size}): {num_batches}")

    # Second pass on a fresh dataset: the first one has been consumed.
    class_pixel_counts, total_pixels = compute_class_frequencies(
        make_train_dataset(), num_classes, num_batches
    )

    # Compute inverse frequency weights
    class_weights, class_frequencies = compute_inverse_frequency_weights(
        class_pixel_counts, num_classes
    )

    # Print results
    print("\n" + rule)
    print("RESULTS")
    print(rule)

    class_names = {
        3: ['Background', 'Ventricles', 'Abnormal WMH'],
        4: ['Background', 'Ventricles', 'Normal WMH', 'Abnormal WMH']
    }

    print(f"\nTotal pixels analyzed: {total_pixels:,}")
    print("\nClass Statistics:")
    print("-" * 70)

    for i in range(num_classes):
        print(f"Class {i} ({class_names[num_classes][i]}):")
        print(f" Pixel count: {class_pixel_counts[i]:,}")
        print(f" Frequency: {class_frequencies[i]:.6f} ({class_frequencies[i]*100:.2f}%)")
        print(f" Weight: {class_weights[i]:.4f}")
        print()

    # Save to JSON
    output_path = Path(output_dir)
    # parents=True so a nested output directory does not fail on first use
    output_path.mkdir(parents=True, exist_ok=True)

    results = {
        'fold_id': fold_id,
        'class_scenario': class_scenario,
        'preprocessing': preprocessing,
        'num_classes': num_classes,
        'total_pixels': int(total_pixels),
        'class_pixel_counts': class_pixel_counts.tolist(),
        'class_frequencies': class_frequencies.tolist(),
        'class_weights': class_weights.tolist(),
        'class_names': class_names[num_classes]
    }

    filename = f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
    filepath = output_path / filename

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(rule)
    print(f"✅ Class weights saved to: {filepath}")
    print(rule)

    # Print weights in format ready for code
    print("\nFor use in training script:")
    print("-" * 70)
    print(f"class_weights = tf.constant({class_weights.tolist()}, dtype=tf.float32)")
    print()

    return results
208
+
209
+
210
def compute_all_scenarios_for_fold(fold_id):
    """
    Run the class-weight computation for every preprocessing / class-scenario
    combination of one fold.

    The four combinations are processed in the order standard/3class,
    standard/4class, zoomed/3class, zoomed/4class.

    Args:
        fold_id: Fold number

    Returns:
        Dictionary mapping "<preprocessing>_<class_scenario>" to the result
        dictionary returned by compute_and_save_class_weights()
    """
    separator = "\n" + "="*70 + "\n"
    results_by_key = {}

    for preprocessing in ('standard', 'zoomed'):
        for class_scenario in ('3class', '4class'):
            outcome = compute_and_save_class_weights(
                fold_id=fold_id,
                class_scenario=class_scenario,
                preprocessing=preprocessing
            )
            results_by_key[f"{preprocessing}_{class_scenario}"] = outcome

            # Visual separator between consecutive scenarios
            print(separator)

    return results_by_key
239
+
240
+
241
def load_class_weights(fold_id, class_scenario, preprocessing, weights_dir='class_weights'):
    """
    Read the class weights stored by compute_and_save_class_weights().

    Args:
        fold_id: Fold number
        class_scenario: '3class' or '4class'
        preprocessing: 'standard' or 'zoomed'
        weights_dir: Directory containing the weights files

    Returns:
        class_weights: float32 NumPy array built from the 'class_weights'
        entry of the JSON file

    Raises:
        FileNotFoundError: If the JSON file for this combination is missing
    """
    filepath = Path(weights_dir) / (
        f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
    )

    if filepath.exists():
        with open(filepath, 'r') as f:
            stored = json.load(f)
        return np.array(stored['class_weights'], dtype=np.float32)

    raise FileNotFoundError(
        f"Class weights not found: {filepath}\n"
        f"Run compute_and_save_class_weights() first."
    )
270
+
271
+
272
def main():
    """
    Command-line entry point.

    Modes, checked in this order:
        --all-folds   every scenario of folds 0-4
        --all         every scenario of the fold given by --fold
        (neither)     the single scenario given by --fold, --scenario and
                      --preprocessing (all three required)

    Missing required options end the program through ``parser.error``
    (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser(
        description='Compute class weights from training data',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The script name here must match this file so the examples can be
        # copied verbatim.
        epilog="""
Examples:
  # Single scenario
  python p4_compute_class_weights.py --fold 0 --scenario 4class --preprocessing standard

  # All scenarios for one fold
  python p4_compute_class_weights.py --fold 0 --all

  # All folds (for completeness)
  python p4_compute_class_weights.py --all-folds
        """
    )

    parser.add_argument(
        '--fold',
        type=int,
        choices=[0, 1, 2, 3, 4],
        help='Fold number (0-4)'
    )

    parser.add_argument(
        '--scenario',
        type=str,
        choices=['3class', '4class'],
        help='Class scenario'
    )

    parser.add_argument(
        '--preprocessing',
        type=str,
        choices=['standard', 'zoomed'],
        help='Preprocessing type'
    )

    parser.add_argument(
        '--all',
        action='store_true',
        help='Compute for all scenarios of specified fold'
    )

    parser.add_argument(
        '--all-folds',
        action='store_true',
        help='Compute for all scenarios of all folds'
    )

    args = parser.parse_args()

    if args.all_folds:
        # Compute for all folds
        # NOTE(review): other files in this upload use 4 folds (0-3) -
        # confirm that fold 4 exists before relying on this mode.
        for fold_id in range(5):
            print(f"\n{'='*70}")
            print(f"PROCESSING FOLD {fold_id}")
            print(f"{'='*70}\n")
            compute_all_scenarios_for_fold(fold_id)

    elif args.all:
        # Compute all scenarios for one fold
        if args.fold is None:
            parser.error("--fold is required when using --all")
        compute_all_scenarios_for_fold(args.fold)

    else:
        # Compute single scenario
        if args.fold is None or args.scenario is None or args.preprocessing is None:
            parser.error("--fold, --scenario, and --preprocessing are required")

        compute_and_save_class_weights(
            fold_id=args.fold,
            class_scenario=args.scenario,
            preprocessing=args.preprocessing
        )


if __name__ == "__main__":
    main()
models/for_WMH_Vent/model_training_scripts/p4_data_loader.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 Article - Data Loading System
3
+
4
+ Complete implementation for brain segmentation experiments
5
+
6
+ WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
7
+ Three-class segmentation: Background vs Ventricles vs Abnormal WMH
8
+ Professional results saving and visualization for publication
9
+
10
+ This relates to our article:
11
+ "Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
12
+ A Large-Scale Multiple Sclerosis Study"
13
+
14
+ Features:
15
+ - Load FLAIR images and individual mask files from Cohort directory
16
+ - Support both Local_SAI (MS3SEG) and Public_MSSEG (MSSEG2016) datasets
17
+ - Handle standard and zoomed preprocessing variants
18
+ - Combine masks into 3-class or 4-class format
19
+ - Create paired inputs: [FLAIR | mask] concatenated (256x512)
20
+ - Patient-stratified K-fold cross-validation
21
+ - TensorFlow dataset creation with proper batching
22
+
23
+ Authors:
24
+ "Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
25
+
26
+ Developer:
27
+ "Mahdi Bashiri Bawil"
28
+ """
29
+
30
+ import numpy as np
31
+ import os
32
+ from pathlib import Path
33
+ from typing import Tuple, List, Dict, Optional
34
+ import json
35
+ from sklearn.model_selection import KFold
36
+ from tqdm import tqdm
37
+ import cv2 as cv
38
+
39
+ # Deep Learning
40
+ import tensorflow as tf
41
+
42
+
43
+ ###################### Configuration ######################
44
+
45
class DataConfig:
    """Data configuration for P4 experiments.

    Holds dataset locations, class-scenario definitions, cross-validation
    parameters and image geometry used by the splitter and the data loader.

    Args:
        cohort_dir: Root directory of the data cohort (str or Path). When
            omitted, the original hard-coded development path is used, so
            existing ``DataConfig()`` callers are unaffected.
        splits_dir: Directory holding the fold-assignment JSON (str or Path).
            Defaults to ``data_splits`` relative to the working directory.
    """

    def __init__(self,
                 cohort_dir: Optional[Path] = None,
                 splits_dir: Optional[Path] = None):
        # Base paths
        if cohort_dir is None:
            # CHANGE THIS (or pass cohort_dir=...) to your actual Data Cohort path
            cohort_dir = "/mnt/e/MBashiri/ours_articles/Paper#2/Data/Cohort"
        self.cohort_dir = Path(cohort_dir)

        # Dataset configurations.
        # 'slice_range' is the inclusive range of slice numbers probed for
        # each patient; slices whose files do not exist are simply not found.
        self.datasets = {
            'Local_SAI_updated': {
                'base_path': self.cohort_dir / 'Local_SAI_updated',
                # original note: "9,15" - presumably an earlier, narrower window
                'slice_range': (1, 20),
                'patient_prefix_length': 6,  # "101228"
            },
            'Public_MSSEG': {
                'base_path': self.cohort_dir / 'Public_MSSEG',
                # original note: "24,43" - presumably an earlier, narrower window
                'slice_range': (1, 50),
                'patient_prefix_length': 6,  # "c01p01"
            },
        }

        # Preprocessing variants
        self.preprocessing_types = ['standard', 'zoomed']

        # Class scenarios
        self.class_scenarios = {
            '3class': {
                'num_classes': 3,
                'class_names': ['Background', 'Ventricles', 'Abnormal WMH'],
                'description': 'Three-class: Background, Ventricles, Abnormal WMH',
                'class_mapping': {
                    'background': 0,
                    'ventricles': 1,
                    'abnormal_wmh': 2,
                },
            },
            '4class': {
                'num_classes': 4,
                'class_names': ['Background', 'Ventricles', 'Normal WMH', 'Abnormal WMH'],
                'description': 'Four-class: Background, Ventricles, Normal WMH, Abnormal WMH',
                'class_mapping': {
                    'background': 0,
                    'ventricles': 1,
                    'normal_wmh': 2,
                    'abnormal_wmh': 3,
                },
            },
        }

        # K-fold parameters
        self.k_folds = 4          # valid fold ids are 0 .. k_folds - 1
        self.test_split = 0.2     # 20% of patients held out as test set
        self.random_state = 42

        # Image parameters
        self.target_size = (256, 256)
        self.paired_width = 512   # FLAIR (256) + mask (256)

        # Paths for splits
        self.splits_dir = Path("data_splits") if splits_dir is None else Path(splits_dir)
        self.splits_file = self.splits_dir / "concat_fold_assignments.json"
106
+
107
+
108
+ ###################### Helper Functions ######################
109
+
110
def extract_patient_id(filename: str, prefix_length: int = 6) -> str:
    """
    Extract patient ID from filename

    Directory components are ignored, so both a bare file name and a full
    path give the same result.

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"
        prefix_length: Number of characters in patient ID

    Returns:
        Patient ID: e.g., "101228" or "c01p01"
    """
    # Work on the final path component only; an underscore in a parent
    # directory name would otherwise be mistaken for the ID separator.
    name = Path(filename).name
    return name.split('_')[0][:prefix_length]
122
+
123
+
124
def extract_slice_number(filename: str) -> int:
    """
    Extract slice number from filename

    Directory components are ignored, so dots or underscores in parent
    directory names do not interfere with parsing.

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"

    Returns:
        Slice number as integer

    Raises:
        ValueError: If the part after the last '_' is not an integer.
    """
    # Final path component, cut at the first dot (drops every extension)
    basename = Path(filename).name.split('.')[0]
    # The slice number is whatever follows the last underscore
    return int(basename.rsplit('_', 1)[-1])
139
+
140
+
141
def load_flair_image(flair_path: Path, normalize: bool = False, of_z_score: bool = False) -> np.ndarray:
    """
    Load a FLAIR slice as a float32 array with a trailing channel axis

    Args:
        flair_path: Path to the .png file. With ``of_z_score=True`` the
            sibling file with a ``.npy`` suffix is read instead.
        normalize: If True, re-apply z-score normalization, but only when
            the loaded data looks un-normalized (std > 2 or |mean| > 1).
        of_z_score: If True, load the already z-scored ``.npy`` data as is;
            otherwise load the PNG and min-max scale it to [-1, 1].

    Returns:
        FLAIR image (H, W, 1) as float32

    Raises:
        FileNotFoundError: If the PNG cannot be read (``of_z_score=False``).
    """
    if of_z_score:
        # Load NPY: the already z-scored FLAIR image data.
        # Only the suffix is swapped; directory names are left untouched.
        npy_path = Path(flair_path).with_suffix('.npy')
        flair = np.load(str(npy_path)).astype(np.float32)
    else:
        # Load PNG as grayscale; cv.imread returns None instead of raising
        raw = cv.imread(str(flair_path), cv.IMREAD_GRAYSCALE)
        if raw is None:
            raise FileNotFoundError(f"Could not load FLAIR image: {flair_path}")
        flair = raw.astype(np.float32)

        # Normalize to [-1, 1]
        lowest = np.min(flair)
        span = np.max(flair) - lowest
        if span > 0:
            flair = (flair - lowest) / span
        else:
            # Constant image: avoid 0/0 -> NaN, map everything to the minimum
            flair = np.zeros_like(flair)
        flair = (2 * flair) - 1

    # Ensure correct shape
    if flair.ndim == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Additional normalization if needed (should already be normalized)
    if normalize and (np.std(flair) > 2.0 or np.abs(np.mean(flair)) > 1.0):
        # Re-normalize if values seem off
        flair = (flair - np.mean(flair)) / (np.std(flair) + 1e-7)

    return flair
173
+
174
+
175
def load_mask_image(mask_path: Path) -> np.ndarray:
    """
    Read a mask PNG and return it as a binary array

    Args:
        mask_path: Location of the .png mask file

    Returns:
        uint8 array holding 1 wherever the stored pixel is non-zero, else 0

    Raises:
        FileNotFoundError: If the image cannot be read
    """
    # Single-channel read; a failed read comes back as None
    grayscale = cv.imread(str(mask_path), cv.IMREAD_GRAYSCALE)
    if grayscale is None:
        raise FileNotFoundError(f"Could not load mask: {mask_path}")

    # Collapse every non-zero intensity to label 1
    binary = grayscale > 0
    return binary.astype(np.uint8)
195
+
196
+
197
def combine_masks(vent_mask: np.ndarray,
                  nwmh_mask: np.ndarray,
                  abwmh_mask: np.ndarray,
                  class_scenario: str,
                  preprocess: bool = False) -> np.ndarray:
    """
    Combine individual masks into multi-class format

    Labels are written in the order ventricles -> normal WMH -> abnormal
    WMH, so where masks overlap the later one wins.

    Args:
        vent_mask: Ventricles mask (H, W)
        nwmh_mask: Normal WMH mask (H, W)
        abwmh_mask: Abnormal WMH mask (H, W)
        class_scenario: '3class' or '4class'
        preprocess: If True, clean each mask morphologically (closing,
            erosion, small-object removal) and make the masks mutually
            exclusive before combining. Requires scikit-image.

    Returns:
        Combined uint8 mask (H, W) with class labels
        (3class: 0/1/2, 4class: 0/1/2/3)

    Raises:
        ValueError: If ``class_scenario`` is not '3class' or '4class'.
    """
    # Fail fast, before any (potentially expensive) morphology is run
    if class_scenario not in ('3class', '4class'):
        raise ValueError(f"Unknown class_scenario: {class_scenario}")

    if preprocess:
        from skimage.morphology import binary_closing, binary_erosion, disk, remove_small_objects
        min_object_size = 5
        closing_radius = 2
        erosion_radius = 1

        def _clean(mask: np.ndarray) -> np.ndarray:
            # Close small gaps, shave the border, then drop tiny specks
            cleaned = binary_closing(mask > 0, disk(closing_radius))
            cleaned = binary_erosion(cleaned, disk(erosion_radius))
            return remove_small_objects(cleaned, min_size=min_object_size)

        vent_mask = _clean(vent_mask)
        nwmh_mask = _clean(nwmh_mask)
        abwmh_mask = _clean(abwmh_mask)

        # Make the masks mutually exclusive: ventricles take precedence
        # over both WMH masks, and normal WMH over abnormal WMH.
        abwmh_mask = abwmh_mask & ~vent_mask
        nwmh_mask = nwmh_mask & ~vent_mask
        abwmh_mask = abwmh_mask & ~nwmh_mask

    # Class 0: Background (default), Class 1: Ventricles
    combined = np.zeros_like(vent_mask, dtype=np.uint8)
    combined[vent_mask > 0] = 1

    if class_scenario == '3class':
        # Class 2: Abnormal WMH
        combined[abwmh_mask > 0] = 2
    else:
        # Class 2: Normal WMH, Class 3: Abnormal WMH
        combined[nwmh_mask > 0] = 2
        combined[abwmh_mask > 0] = 3

    return combined
263
+
264
+
265
def is_valid_slice(vent_mask: np.ndarray,
                   nwmh_mask: np.ndarray,
                   abwmh_mask: np.ndarray,
                   require_content: bool = False,
                   min_pixels: int = 50) -> bool:
    """
    Decide whether a slice should be kept

    By default every slice is accepted, including slices whose masks are
    all empty; this matches the behaviour the existing pipeline runs with.
    Content filtering is available as an opt-in.

    Args:
        vent_mask: Ventricles mask (H, W)
        nwmh_mask: Normal WMH mask (H, W)
        abwmh_mask: Abnormal WMH mask (H, W)
        require_content: If True, keep the slice only when at least one
            mask has more than ``min_pixels`` foreground pixels.
        min_pixels: Foreground-pixel threshold used when
            ``require_content`` is True.

    Returns:
        True if the slice should be kept
    """
    if not require_content:
        # Filtering disabled: accept the slice without inspecting the masks
        return True

    # Valid if ANY mask has enough content
    return any(int(np.sum(m)) > min_pixels
               for m in (vent_mask, nwmh_mask, abwmh_mask))
285
+
286
+
287
def create_paired_input(flair: np.ndarray,
                        mask: np.ndarray,
                        brain_mask: np.ndarray,
                        num_classes: int,
                        if_bet=False) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create paired input: [FLAIR | mask] concatenated horizontally

    The caller's arrays are never modified; when brain extraction is
    requested it is applied to copies.

    Args:
        flair: FLAIR image (H, W) or (H, W, 1), float32
        mask: Combined label mask (H, W), uint8
        brain_mask: Brain mask (H, W); any non-zero pixel counts as brain
        num_classes: Divisor used to scale the labels. The loader in this
            file passes the highest label value (number of classes - 1),
            which maps labels onto [-1, 1]. A value <= 0 skips the division.
        if_bet: If True, apply brain extraction: outside the brain, FLAIR
            is set to its minimum value and the label mask to 0.

    Returns:
        Tuple (paired, mask):
            - paired: (H, 2W, 1) array, FLAIR on the left, scaled mask on
              the right
            - mask: the label mask (brain-extracted when ``if_bet`` is True)
    """
    # Binarize (any non-zero value becomes True)
    inside_brain = brain_mask > 0

    # Brain extraction (on copies, so the inputs stay untouched)
    if if_bet:
        flair = flair.copy()
        mask = mask.copy()
        flair[~inside_brain] = np.min(flair)
        mask[~inside_brain] = 0

    # Ensure flair is 3D
    if flair.ndim == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Scale labels to [-1, 1] to match the FLAIR intensity range
    # For 3-class (divisor 2): 0, 1, 2    -> -1, 0, 1
    # For 4-class (divisor 3): 0, 1, 2, 3 -> -1, -0.333, 0.333, 1
    mask_scaled = mask.astype(np.float32)
    if num_classes > 0:
        mask_scaled = mask_scaled / num_classes
    mask_scaled = (2 * mask_scaled) - 1

    # Concatenate horizontally: [FLAIR | mask]
    paired = np.concatenate([flair, np.expand_dims(mask_scaled, axis=-1)], axis=1)

    return paired, mask
330
+
331
+
332
+ ###################### Patient Stratified Splitting ######################
333
+
334
class PatientStratifiedSplitter:
    """
    Create patient-level train/val/test splits

    A fixed fraction of the patients is held out as a test set and the
    remaining patients are divided into K folds. All splitting operates on
    patient IDs, so every slice of a patient lands in the same partition.
    """

    def __init__(self, config: "DataConfig"):
        self.config = config
        # parents=True so a nested splits_dir can be created from scratch
        self.config.splits_dir.mkdir(parents=True, exist_ok=True)

    def collect_all_patients(self) -> Dict[str, List[str]]:
        """
        Collect all unique patient IDs from every configured dataset

        Patient IDs are derived from the file names of the standard
        (non-zoomed) preprocessed FLAIR PNGs. Datasets whose image
        directory is missing are skipped with a warning.

        Returns:
            Dictionary mapping dataset_name -> sorted list of patient IDs
        """
        all_patients = {}

        for dataset_name, dataset_config in self.config.datasets.items():
            # Path to FLAIR images (standard preprocessing)
            flair_dir = dataset_config['base_path'] / 'FLAIR' / 'Preprocessed' / 'images'

            if not flair_dir.exists():
                print(f"Warning: {flair_dir} does not exist. Skipping {dataset_name}.")
                continue

            prefix_length = dataset_config['patient_prefix_length']
            patients = {
                extract_patient_id(flair_file.name, prefix_length)
                for flair_file in flair_dir.glob('*.png')
            }

            all_patients[dataset_name] = sorted(patients)
            print(f"{dataset_name}: {len(all_patients[dataset_name])} patients")

        return all_patients

    def create_patient_stratified_splits(self,
                                         save: bool = True) -> Dict:
        """
        Create a held-out test set plus K-fold train/val splits over patients

        Args:
            save: If True, write the assignments to ``config.splits_file``
                (an existing file is overwritten).

        Returns:
            Dictionary containing fold assignments with keys 'metadata',
            'test_set' and 'folds'

        Raises:
            ValueError: If no patients were found in any dataset.
        """
        all_patients = self.collect_all_patients()

        # Combine patients from all datasets (dataset order, sorted within)
        combined_patients = np.array(
            [pid for patients in all_patients.values() for pid in patients]
        )
        total_patients = len(combined_patients)

        print(f"\nTotal unique patients: {total_patients}")

        if total_patients == 0:
            raise ValueError(
                "No patients found; check the dataset paths in the configuration."
            )

        # Step 1: Split into train+val and test.
        # A private RandomState draws the same values as seeding the global
        # NumPy generator, without changing random state for other code.
        rng = np.random.RandomState(self.config.random_state)
        test_size = int(total_patients * self.config.test_split)

        test_indices = rng.choice(
            total_patients,
            size=test_size,
            replace=False
        )

        test_patients = combined_patients[test_indices]
        train_val_indices = np.setdiff1d(np.arange(total_patients), test_indices)
        train_val_patients = combined_patients[train_val_indices]

        print(f"Test patients: {len(test_patients)}")
        print(f"Train+Val patients: {len(train_val_patients)}")

        # Step 2: Create K-fold splits on train+val patients
        kfold = KFold(
            n_splits=self.config.k_folds,
            shuffle=True,
            random_state=self.config.random_state
        )

        fold_assignments = {
            'metadata': {
                'total_patients': total_patients,
                'test_patients': len(test_patients),
                'trainval_patients': len(train_val_patients),
                'n_folds': self.config.k_folds,
                'random_seed': self.config.random_state,
                'datasets': list(all_patients.keys())
            },
            'test_set': {
                'patients': test_patients.tolist(),
                'n_patients': len(test_patients)
            },
            'folds': {}
        }

        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(train_val_patients)):
            train_patients_fold = train_val_patients[train_idx]
            val_patients_fold = train_val_patients[val_idx]

            fold_assignments['folds'][f'fold_{fold_idx}'] = {
                'train_patients': train_patients_fold.tolist(),
                'val_patients': val_patients_fold.tolist(),
                'n_train': len(train_patients_fold),
                'n_val': len(val_patients_fold)
            }

            print(f"Fold {fold_idx}: Train={len(train_patients_fold)}, Val={len(val_patients_fold)}")

        # Save to JSON
        if save:
            with open(self.config.splits_file, 'w', encoding='utf-8') as f:
                json.dump(fold_assignments, f, indent=2)
            print(f"\n✅ Fold assignments saved to: {self.config.splits_file}")

        return fold_assignments

    def load_fold_assignments(self) -> Dict:
        """
        Load existing fold assignments from JSON

        Raises:
            FileNotFoundError: If ``config.splits_file`` does not exist.
        """
        if not self.config.splits_file.exists():
            raise FileNotFoundError(
                f"Fold assignments not found: {self.config.splits_file}\n"
                f"Run create_patient_stratified_splits() first."
            )

        with open(self.config.splits_file, 'r', encoding='utf-8') as f:
            fold_assignments = json.load(f)

        return fold_assignments

    def verify_patient_separation(self, fold_assignments: Dict) -> bool:
        """
        Check a set of fold assignments for patient-level leakage

        Three checks are run and every violation is printed:
            1. no test patient appears in any fold's train or val list
            2. no patient is in both train and val of the same fold
            3. no patient appears in the validation lists more than once
               across folds

        Args:
            fold_assignments: Dictionary with 'test_set' and 'folds' entries

        Returns:
            True if no issue was found, False otherwise
        """
        print("\n" + "="*60)
        print("VERIFYING PATIENT SEPARATION")
        print("="*60)

        all_issues = []
        folds = fold_assignments['folds']
        test_patients = set(fold_assignments['test_set']['patients'])

        def _report(issue: str) -> None:
            all_issues.append(issue)
            print(f"❌ {issue}")

        # Check 1: No patient in both test and train/val
        for fold_name, fold_data in folds.items():
            test_train_overlap = test_patients.intersection(fold_data['train_patients'])
            test_val_overlap = test_patients.intersection(fold_data['val_patients'])

            if test_train_overlap:
                _report(f"{fold_name}: Test-Train overlap: {test_train_overlap}")

            if test_val_overlap:
                _report(f"{fold_name}: Test-Val overlap: {test_val_overlap}")

        # Check 2: No patient in both train and val within same fold
        for fold_name, fold_data in folds.items():
            train_val_overlap = set(fold_data['train_patients']).intersection(fold_data['val_patients'])
            if train_val_overlap:
                _report(f"{fold_name}: Train-Val overlap: {train_val_overlap}")

        # Check 3: Each validation patient appears exactly once across folds
        val_patient_counts = {}
        for fold_data in folds.values():
            for patient in fold_data['val_patients']:
                val_patient_counts[patient] = val_patient_counts.get(patient, 0) + 1

        for patient, count in val_patient_counts.items():
            if count != 1:
                _report(f"Patient {patient} in validation {count} times (should be 1)")

        if not all_issues:
            print("✅ All patient separation checks passed")
            print("✅ No data leakage detected")
            return True

        print(f"\n❌ Found {len(all_issues)} issues")
        return False
533
+
534
+
535
+ ###################### Data Loader ######################
536
+
537
class P2DataLoader:
    """
    Main data loader for the P4 experiments (class name kept for compatibility)

    Handles locating FLAIR and mask files, building paired inputs, and
    creating TensorFlow datasets for a given fold and split.
    """

    def __init__(self, config: "DataConfig"):
        self.config = config

    def get_file_paths(self,
                       patient_id: str,
                       slice_num: int,
                       dataset_name: str,
                       preprocessing: str) -> Dict[str, Path]:
        """
        Construct file paths for a given patient-slice

        No file-system access happens here; the returned paths may not exist.

        Args:
            patient_id: e.g., "101228" or "c01p01"
            slice_num: Slice number
            dataset_name: 'Local_SAI_updated' or 'Public_MSSEG'
            preprocessing: 'standard'; any other value selects the zoomed
                directory layout

        Returns:
            Dictionary with paths to the FLAIR image, the four mask files,
            and the per-patient zooming factors ('zoom_factors' is None
            unless preprocessing == 'zoomed')
        """
        base_path = self.config.datasets[dataset_name]['base_path']

        # FLAIR and ground truth share the same sub-directory layout
        subdir = 'images' if preprocessing == 'standard' else 'zoomed/images'
        file_name = f'{patient_id}_{slice_num}.png'

        flair_root = base_path / 'FLAIR' / 'Preprocessed'
        gt_root = base_path / 'GroundTruth' / subdir

        # Optional: zooming factors (only for zoomed preprocessing)
        zoom_factors_path = None
        if preprocessing == 'zoomed':
            zoom_factors_path = flair_root / 'zoomed' / 'images' / f'{patient_id}_zooming_factors.npy'

        return {
            'flair': flair_root / subdir / file_name,
            'vent_mask': gt_root / 'Vent_Masks' / file_name,
            'nwmh_mask': gt_root / 'nWMH_Masks' / file_name,
            'abwmh_mask': gt_root / 'abWMH_Masks' / file_name,
            'brain_mask': gt_root / 'Brain_Masks' / file_name,
            'zoom_factors': zoom_factors_path
        }

    def load_single_slice(self,
                          patient_id: str,
                          slice_num: int,
                          dataset_name: str,
                          preprocessing: str,
                          class_scenario: str,
                          of_z_score: bool = True,
                          if_bet: bool = True,
                          pre_morph: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load a single patient-slice and create paired input

        Args:
            patient_id: Patient identifier
            slice_num: Slice number
            dataset_name: 'Local_SAI_updated' or 'Public_MSSEG'
            preprocessing: 'standard' or 'zoomed'
            class_scenario: '3class' or '4class'
            of_z_score: Load the pre-computed z-scored .npy FLAIR instead
                of the PNG
            if_bet: Apply brain extraction using the brain mask
            pre_morph: Apply morphological cleaning to the masks before
                combining them

        Returns:
            Tuple of (paired_input, combined_mask)
                - paired_input: (256, 512, 1) FLAIR + mask concatenated
                - combined_mask: (256, 256) multi-class labels

        Raises:
            ValueError: If ``class_scenario`` is not defined in the config.
        """
        # Highest label value, taken from the configuration rather than
        # parsed from the first character of the scenario name
        scenario = self.config.class_scenarios.get(class_scenario)
        if scenario is None:
            raise ValueError(f"Unknown class_scenario: {class_scenario}")
        max_label = scenario['num_classes'] - 1

        # Get file paths
        paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

        # Load FLAIR
        flair = load_flair_image(paths['flair'], of_z_score=of_z_score)

        # Load masks
        vent_mask = load_mask_image(paths['vent_mask'])
        nwmh_mask = load_mask_image(paths['nwmh_mask'])
        abwmh_mask = load_mask_image(paths['abwmh_mask'])
        brain_mask = load_mask_image(paths['brain_mask'])

        # Combine masks
        combined_mask = combine_masks(vent_mask, nwmh_mask, abwmh_mask, class_scenario, preprocess=pre_morph)

        # Create paired input
        paired_input, combined_mask = create_paired_input(
            flair, combined_mask, brain_mask, num_classes=max_label, if_bet=if_bet
        )

        return paired_input, combined_mask

    def collect_patient_slices(self,
                               patient_list: List[str],
                               dataset_name: str,
                               preprocessing: str) -> List[Tuple[str, int, str]]:
        """
        Collect all usable slices for the given patients in one dataset

        A slice is kept when its FLAIR image and all four mask files exist,
        the masks can be read, and is_valid_slice() accepts it. Patients
        that have no files in this dataset simply contribute nothing.

        Args:
            patient_list: List of patient IDs
            dataset_name: 'Local_SAI_updated' or 'Public_MSSEG'
            preprocessing: 'standard' or 'zoomed'

        Returns:
            List of tuples (patient_id, slice_num, dataset_name)
        """
        slice_min, slice_max = self.config.datasets[dataset_name]['slice_range']
        required = ('flair', 'vent_mask', 'nwmh_mask', 'abwmh_mask', 'brain_mask')

        patient_slices = []
        skipped_empty = 0  # counts rejected AND unreadable slices

        for patient_id in patient_list:
            for slice_num in range(slice_min, slice_max + 1):
                paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

                # Skip slice numbers for which any required file is missing
                if not all(paths[key].exists() for key in required):
                    continue

                # VALIDATION: masks must be readable and accepted
                try:
                    vent_mask = load_mask_image(paths['vent_mask'])
                    nwmh_mask = load_mask_image(paths['nwmh_mask'])
                    abwmh_mask = load_mask_image(paths['abwmh_mask'])
                    # Loaded only to confirm the brain mask is readable
                    load_mask_image(paths['brain_mask'])

                    if is_valid_slice(vent_mask, nwmh_mask, abwmh_mask):
                        patient_slices.append((patient_id, slice_num, dataset_name))
                    else:
                        skipped_empty += 1

                except Exception as e:
                    # Best effort: an unreadable slice is reported and dropped
                    print(f"Warning: Could not validate {patient_id}_{slice_num}: {e}")
                    skipped_empty += 1

        if skipped_empty > 0:
            print(f" ⚠️ Skipped {skipped_empty} slices with empty masks")

        return patient_slices

    def create_dataset_for_fold(self,
                                fold_id: int,
                                split: str,
                                preprocessing: str,
                                class_scenario: str,
                                batch_size: int = 1,
                                shuffle: bool = True,
                                use_z_scored: bool = True,
                                bet: bool = False) -> "tf.data.Dataset":
        """
        Create TensorFlow dataset for a specific fold and split

        Args:
            fold_id: Fold number (0 .. config.k_folds - 1); ignored for the
                'test' split
            split: 'train', 'val', or 'test'
            preprocessing: 'standard' or 'zoomed'
            class_scenario: '3class' or '4class'
            batch_size: Batch size
            shuffle: Whether to shuffle data (only applied to 'train')
            use_z_scored: Load the pre-computed z-scored FLAIR data
            bet: Accepted for interface compatibility but NOT forwarded.
                Brain extraction is always applied (load_single_slice's
                default); forwarding this flag's False default would
                silently disable it for every existing caller.

        Returns:
            tf.data.Dataset yielding batches of
            (paired_input, combined_mask, patient_id, slice_num)

        Raises:
            ValueError: If ``split`` is unknown or no slices were found.
        """
        # Load fold assignments
        splitter = PatientStratifiedSplitter(self.config)
        fold_assignments = splitter.load_fold_assignments()

        # Get patient list for this split
        if split == 'test':
            patient_list = fold_assignments['test_set']['patients']
        else:
            fold_key = f'fold_{fold_id}'
            if split == 'train':
                patient_list = fold_assignments['folds'][fold_key]['train_patients']
            elif split == 'val':
                patient_list = fold_assignments['folds'][fold_key]['val_patients']
            else:
                raise ValueError(f"Unknown split: {split}")

        print(f"\nCreating dataset for fold {fold_id}, split '{split}'")
        print(f"Patients: {len(patient_list)}")

        # Collect all patient-slices from every dataset. The full patient
        # list is probed against each dataset; patients belonging to another
        # dataset have no files there and therefore yield no slices.
        all_patient_slices = []

        for dataset_name in self.config.datasets.keys():
            patient_slices = self.collect_patient_slices(
                list(patient_list),
                dataset_name,
                preprocessing
            )
            all_patient_slices.extend(patient_slices)

        print(f"Total slices: {len(all_patient_slices)}")

        if len(all_patient_slices) == 0:
            raise ValueError(f"No data found for fold {fold_id}, split '{split}'")

        # Create TensorFlow dataset
        def data_generator():
            """Generator function for tf.data.Dataset"""
            for patient_id, slice_num, dataset_name in all_patient_slices:
                try:
                    paired_input, combined_mask = self.load_single_slice(
                        patient_id, slice_num, dataset_name,
                        preprocessing, class_scenario,
                        of_z_score=use_z_scored
                    )
                    yield paired_input, combined_mask, patient_id, slice_num
                except Exception as e:
                    # Best effort: report and skip a slice that fails to load
                    print(f"Error loading {patient_id}_{slice_num}: {e}")
                    continue

        # Create dataset
        dataset = tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(256, 512, 1), dtype=tf.float32),  # concatenated image
                tf.TensorSpec(shape=(256, 256), dtype=tf.uint8),       # multi-level mask
                tf.TensorSpec(shape=(), dtype=tf.string),              # patient_id
                tf.TensorSpec(shape=(), dtype=tf.int32)                # slice_num
            )
        )

        # Cache BEFORE shuffle/batch: the generator (file loading, decoding,
        # mask combination) runs only during the first full pass; later
        # epochs are served from the cache. Because the cache holds
        # unbatched, unshuffled samples, shuffling below still produces a
        # fresh order, and thus different batches, every epoch.
        dataset = dataset.cache()

        # Shuffle if training (acts on the cached samples every epoch)
        if shuffle and split == 'train':
            dataset = dataset.shuffle(
                buffer_size=len(all_patient_slices),
                reshuffle_each_iteration=True  # new random order each epoch
            )

        # Batch and prefetch
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset
808
+
809
+
810
+ ###################### Testing & Validation Functions ######################
811
+
812
def test_data_loading():
    """
    Smoke-test the data loading pipeline end to end

    Runs three checks: fold assignments (loaded or created) pass the
    patient-separation verification, a single slice can be loaded, and a
    TensorFlow dataset can be built and iterated.

    Returns:
        True if every check passed, False otherwise
    """
    print("\n" + "="*60)
    print("TESTING DATA LOADING")
    print("="*60)

    config = DataConfig()

    # Test 1: Obtain fold assignments.
    # An existing splits file is reused rather than regenerated, so running
    # this test never overwrites assignments that experiments depend on.
    print("\n[TEST 1] Creating patient stratified splits...")
    splitter = PatientStratifiedSplitter(config)
    if config.splits_file.exists():
        print(f"Using existing fold assignments: {config.splits_file}")
        fold_assignments = splitter.load_fold_assignments()
    else:
        fold_assignments = splitter.create_patient_stratified_splits(save=True)

    # Verify patient separation
    is_valid = splitter.verify_patient_separation(fold_assignments)

    if not is_valid:
        print("❌ Patient separation verification failed!")
        return False

    # Test 2: Load a single slice
    print("\n[TEST 2] Loading single slice...")
    loader = P2DataLoader(config)

    # Get a test patient from fold 0 train set
    test_patient = fold_assignments['folds']['fold_0']['train_patients'][0]

    # Determine which dataset this patient belongs to
    # (public MSSEG IDs look like "c01p01", local IDs are numeric)
    if test_patient.startswith('c'):
        test_dataset = 'Public_MSSEG'
        test_slice = 25  # mid-volume slice of the configured 1-50 range
    else:
        test_dataset = 'Local_SAI_updated'
        test_slice = 10  # mid-volume slice of the configured 1-20 range

    try:
        paired_input, combined_mask = loader.load_single_slice(
            test_patient, test_slice, test_dataset,
            'standard', '4class'
        )

        print(f"✅ Loaded slice {test_patient}_{test_slice}")
        print(f" Paired input shape: {paired_input.shape}")
        print(f" Combined mask shape: {combined_mask.shape}")
        print(f" Mask unique values: {np.unique(combined_mask)}")

    except Exception as e:
        print(f"❌ Failed to load slice: {e}")
        return False

    # Test 3: Create TensorFlow dataset
    print("\n[TEST 3] Creating TensorFlow dataset...")
    try:
        dataset = loader.create_dataset_for_fold(
            fold_id=0,
            split='train',
            preprocessing='standard',
            class_scenario='4class',
            batch_size=2,
            shuffle=True
        )

        # Get first batch; each element is a 4-tuple
        # (paired_input, combined_mask, patient_id, slice_num)
        for batch_paired, batch_masks, _patient_ids, _slice_nums in dataset.take(1):
            print(f"✅ Created dataset")
            print(f" Batch paired input shape: {batch_paired.shape}")
            print(f" Batch masks shape: {batch_masks.shape}")
            print(f" Paired input dtype: {batch_paired.dtype}")
            print(f" Masks dtype: {batch_masks.dtype}")

    except Exception as e:
        print(f"❌ Failed to create dataset: {e}")
        return False

    print("\n" + "="*60)
    print("✅ ALL TESTS PASSED")
    print("="*60)

    return True
891
+
892
+
893
+ ###################### Main Execution ######################
894
+
895
if __name__ == "__main__":
    # Run the smoke tests and tell the user what to do next
    success = test_data_loading()

    if success:
        print("\n" + "="*60)
        print("DATA LOADER READY FOR USE")
        print("="*60)
        print("\nNext steps:")
        # The configured splits file is concat_fold_assignments.json
        print("1. Verify concat_fold_assignments.json is present in data_splits/")
        print("2. Check that all file paths are correct for your system")
        print("3. Proceed to model implementation")
    else:
        print("\n" + "="*60)
        print("❌ DATA LOADER TESTS FAILED")
        print("="*60)
        print("\nPlease fix the issues above before proceeding")
912
+
models/for_WMH_Vent/model_training_scripts/p4_error_analysis.py ADDED
@@ -0,0 +1,1033 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P2 Article - Error Analysis & Hard Case Ranking Module
3
+ for Ventricles and WMH Segmentation
4
+
5
+ Integrates with p4_inference.py to identify problematic slices and patients,
6
+ rank them by difficulty, and produce rich diagnostic visualizations.
7
+
8
+ Developer: Mahdi Bashiri Bawil
9
+ """
10
+
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.gridspec as gridspec
14
+ import matplotlib.patches as mpatches
15
+ from matplotlib.colors import ListedColormap, BoundaryNorm
16
+ import pandas as pd
17
+ import json
18
+ from pathlib import Path
19
+ from collections import defaultdict
20
+ from scipy.ndimage import binary_erosion, label as scipy_label
21
+ from tqdm import tqdm
22
+
23
+
24
+ # ─────────────────────────────────────────────────────────────────────────────
25
+ # SECTION 1 — Slice-level metric computation
26
+ # ─────────────────────────────────────────────────────────────────────────────
27
+
28
+ def _dice_binary(gt_bin, pred_bin):
29
+ """Dice for a single binary mask pair. Returns NaN if both are empty."""
30
+ tp = np.sum(gt_bin & pred_bin)
31
+ denom = np.sum(gt_bin) + np.sum(pred_bin)
32
+ if denom == 0:
33
+ return np.nan # class truly absent — not a failure
34
+ return float(2 * tp / (denom + 1e-7))
35
+
36
+
37
+ def _iou_binary(gt_bin, pred_bin):
38
+ tp = np.sum(gt_bin & pred_bin)
39
+ denom = np.sum(gt_bin | pred_bin)
40
+ if denom == 0:
41
+ return np.nan
42
+ return float(tp / (denom + 1e-7))
43
+
44
+
45
+ def _precision_recall(gt_bin, pred_bin):
46
+ tp = np.sum(gt_bin & pred_bin)
47
+ fp = np.sum(~gt_bin & pred_bin)
48
+ fn = np.sum(gt_bin & ~pred_bin)
49
+ precision = float(tp / (tp + fp + 1e-7))
50
+ recall = float(tp / (tp + fn + 1e-7))
51
+ return precision, recall
52
+
53
+
54
+ def _false_positive_volume(gt_bin, pred_bin):
55
+ """Fraction of predicted pixels that are false positives."""
56
+ fp = np.sum(~gt_bin & pred_bin)
57
+ total_pred = np.sum(pred_bin)
58
+ if total_pred == 0:
59
+ return 0.0
60
+ return float(fp / total_pred)
61
+
62
+
63
+ def _false_negative_volume(gt_bin, pred_bin):
64
+ """Fraction of GT pixels that are missed."""
65
+ fn = np.sum(gt_bin & ~pred_bin)
66
+ total_gt = np.sum(gt_bin)
67
+ if total_gt == 0:
68
+ return 0.0
69
+ return float(fn / total_gt)
70
+
71
+
72
+ def _gt_load(gt_hw, class_idx):
73
+ """Return binary GT mask for a specific class from a (H,W) label map."""
74
+ return gt_hw == class_idx
75
+
76
+
77
+ def _pred_load(pred_hw, class_idx):
78
+ return pred_hw == class_idx
79
+
80
+
81
def compute_slice_metrics(gt_hw, pred_hw, num_classes, class_names,
                          mean_confidence=None):
    """
    Compute per-class and summary metrics for a single 2-D slice.

    Parameters
    ----------
    gt_hw : np.ndarray (H, W) — integer label map (ground truth)
    pred_hw : np.ndarray (H, W) — integer label map (prediction)
    num_classes : int
    class_names : list[str] — one name per class index; used as result keys
    mean_confidence : float | None — passed through unchanged into the summary

    Returns
    -------
    dict
        One entry per class name holding 'dice', 'iou', 'precision', 'recall',
        'fp_rate', 'fn_rate', 'gt_pixels', 'pred_pixels' and 'error_pixels',
        plus a '_summary' entry holding 'error_rate', 'wrong_pixels',
        'total_pixels', 'mean_fg_dice', 'min_fg_dice' and 'mean_confidence'.
        Dice / IoU are NaN for a class that is absent from both maps, and the
        foreground aggregates are NaN when no foreground class has a defined
        Dice.
    """
    results = {}
    fg_dice = []  # defined Dice values of the foreground classes (index > 0)

    for cls in range(num_classes):
        gt_bin = _gt_load(gt_hw, cls)
        pred_bin = _pred_load(pred_hw, cls)

        dice = _dice_binary(gt_bin, pred_bin)
        prec, rec = _precision_recall(gt_bin, pred_bin)

        results[class_names[cls]] = {
            'dice': dice,
            'iou': _iou_binary(gt_bin, pred_bin),
            'precision': prec,
            'recall': rec,
            'fp_rate': _false_positive_volume(gt_bin, pred_bin),
            'fn_rate': _false_negative_volume(gt_bin, pred_bin),
            'gt_pixels': int(np.sum(gt_bin)),
            'pred_pixels': int(np.sum(pred_bin)),
            'error_pixels': int(np.sum(gt_bin != pred_bin)),
        }

        # Background (class 0) is excluded from the composite score, and a
        # NaN Dice means the class is simply absent from this slice.
        if cls > 0 and not np.isnan(dice):
            fg_dice.append(dice)

    # Pixel-level error rate (ignoring class)
    total_px = gt_hw.size
    wrong_px = int(np.sum(gt_hw != pred_hw))
    # An empty slice has no defined error rate; report NaN instead of raising.
    error_rate = wrong_px / total_px if total_px else np.nan

    results['_summary'] = {
        'error_rate': error_rate,
        'wrong_pixels': wrong_px,
        'total_pixels': total_px,
        'mean_fg_dice': float(np.mean(fg_dice)) if fg_dice else np.nan,
        'min_fg_dice': float(np.min(fg_dice)) if fg_dice else np.nan,
        'mean_confidence': mean_confidence,
    }

    return results
158
+
159
+
160
+ # ─────────────────────────────────────────────────────────────────────────────
161
+ # SECTION 2 — Build slice-level and patient-level tables
162
+ # ─────────────────────────────────────────────────────────────────────────────
163
+
164
def build_error_tables(patient_results, num_classes, class_names):
    """
    Turn the per-patient inference results into two metric tables.

    Every slice of every patient is scored with compute_slice_metrics; the
    slice rows are then aggregated into one summary row per patient.

    Parameters
    ----------
    patient_results : dict
        {patient_id: {'predictions', 'ground_truths', 'probabilities',
                      'flairs', 'slice_indices'}}
    num_classes : int
    class_names : list[str]

    Returns
    -------
    slice_df : pd.DataFrame — one row per 2-D slice
    patient_df : pd.DataFrame — one row per patient (aggregated)
    """
    # (column suffix, key in the per-class metrics dict), in column order.
    column_map = (
        ('dice', 'dice'),
        ('iou', 'iou'),
        ('precision', 'precision'),
        ('recall', 'recall'),
        ('fp_rate', 'fp_rate'),
        ('fn_rate', 'fn_rate'),
        ('gt_px', 'gt_pixels'),
        ('pred_px', 'pred_pixels'),
        ('err_px', 'error_pixels'),
    )

    def _prefix(cls):
        # Column prefix for a class: lower-cased name with spaces replaced.
        return class_names[cls].lower().replace(' ', '_')

    slice_rows = []
    patient_rows = []

    progress = tqdm(patient_results.items(), desc="Building error tables")
    for patient_id, data in progress:
        # Process slices in ascending slice-index order.
        order = np.argsort(data['slice_indices'])

        preds = np.array(data['predictions'])[order]
        gts = np.array(data['ground_truths'])[order]
        probs = np.array(data['probabilities'])[order]
        slices = np.array(data['slice_indices'])[order]

        # A 4-D ground truth is treated as one-hot and collapsed to labels.
        if gts.ndim == 4:
            gts = np.argmax(gts, axis=-1)

        fg_dice_by_class = defaultdict(list)
        error_rates = []

        for idx, slice_num in enumerate(slices):
            metrics = compute_slice_metrics(
                gts[idx], preds[idx], num_classes, class_names,
                mean_confidence=float(np.mean(probs[idx])))
            summary = metrics['_summary']

            row = {
                'patient_id': patient_id,
                'slice_num': int(slice_num),
                'slice_id': f"{patient_id}_slice_{int(slice_num):03d}",
                'error_rate': summary['error_rate'],
                'wrong_pixels': summary['wrong_pixels'],
                'mean_fg_dice': summary['mean_fg_dice'],
                'min_fg_dice': summary['min_fg_dice'],
                'mean_confidence': summary['mean_confidence'],
            }

            for cls in range(num_classes):
                cname = class_names[cls]
                per_class = metrics[cname]
                prefix = _prefix(cls)
                for suffix, key in column_map:
                    row[f'{prefix}_{suffix}'] = per_class[key]

                # Only defined foreground Dice values feed the patient stats.
                if cls > 0 and not np.isnan(per_class['dice']):
                    fg_dice_by_class[cname].append(per_class['dice'])

            error_rates.append(summary['error_rate'])
            slice_rows.append(row)

        # ── Patient summary ──
        patient_row = {
            'patient_id': patient_id,
            'n_slices': len(slices),
            'mean_error_rate': float(np.mean(error_rates)),
        }
        for cls in range(1, num_classes):
            vals = fg_dice_by_class[class_names[cls]]
            prefix = _prefix(cls)
            if vals:
                mean_d = float(np.mean(vals))
                std_d = float(np.std(vals))
                min_d = float(np.min(vals))
            else:
                mean_d = std_d = min_d = np.nan
            patient_row[f'{prefix}_mean_dice'] = mean_d
            patient_row[f'{prefix}_std_dice'] = std_d
            patient_row[f'{prefix}_min_dice'] = min_d

        # Composite: mean of the defined per-class mean Dice (foreground only)
        mean_keys = [f'{_prefix(c)}_mean_dice' for c in range(1, num_classes)]
        fg_means = [patient_row[k] for k in mean_keys
                    if not np.isnan(patient_row[k])]
        patient_row['composite_dice'] = (
            float(np.mean(fg_means)) if fg_means else np.nan)

        patient_rows.append(patient_row)

    return pd.DataFrame(slice_rows), pd.DataFrame(patient_rows)
269
+
270
+
271
+ # ─────────────────────────────────────────────────────────────────────────────
272
+ # SECTION 3 — Composite difficulty score & ranking
273
+ # ─────────────────────────────────────────────────────────────────────────────
274
+
275
def rank_slices(slice_df, class_names, num_classes,
                fg_dice_weight=0.6, error_rate_weight=0.2,
                confidence_weight=0.2):
    """
    Return a copy of slice_df ordered from hardest to easiest slice.

    Two columns are added:

    - `difficulty_score` (higher = harder), computed as
          fg_dice_weight * (1 - mean_fg_dice)
        + error_rate_weight * error_rate
        + confidence_weight * (1 - mean_confidence)
    - `difficulty_rank`, starting at 1 for the hardest slice.

    A missing mean_fg_dice or mean_confidence is scored as a neutral 0.5, so
    slices without any foreground class are not pushed to either extreme.
    `class_names` and `num_classes` are accepted but not used here.
    """
    ranked = slice_df.copy()

    # Neutral stand-in (0.5) for undefined values, used for scoring only.
    dice_for_score = ranked['mean_fg_dice'].fillna(0.5)
    conf_for_score = ranked['mean_confidence'].fillna(0.5)

    ranked['difficulty_score'] = (
        fg_dice_weight * (1 - dice_for_score) +
        error_rate_weight * ranked['error_rate'] +
        confidence_weight * (1 - conf_for_score)
    )

    ranked = ranked.sort_values('difficulty_score', ascending=False)
    ranked = ranked.reset_index(drop=True)
    ranked['difficulty_rank'] = ranked.index + 1

    return ranked
304
+
305
+
306
def rank_patients(patient_df):
    """
    Return a copy of patient_df ordered by ascending `composite_dice`
    (hardest patient first), with a 1-based `difficulty_rank` column added.
    """
    ranked = (
        patient_df.copy()
        .sort_values('composite_dice', ascending=True)
        .reset_index(drop=True)
    )
    ranked['difficulty_rank'] = ranked.index + 1
    return ranked
312
+
313
+
314
+ # ─────────────────────────────────────────────────────────────────────────────
315
+ # SECTION 4 — Visualization helpers
316
+ # ─────────────────────────────────────────────────────────────────────────────
317
+
318
# Display colour per class index for the 3-class scenario.
CLASS_COLORS_3 = [
    'black',    # BG
    '#2196F3',  # Vent
    '#F44336',  # WMH
]

# Display colour per class index for the 4-class scenario.
CLASS_COLORS_4 = [
    'black',    # BG
    '#2196F3',  # Vent
    '#4CAF50',  # NormWMH
    '#F44336',  # AbWMH
]

# Colours for the four error categories, in category-code order.
ERROR_CMAP = ListedColormap([
    '#1A1A1A',  # correct background
    '#FF5722',  # FP (pred fg, gt bg)
    '#03A9F4',  # FN (gt fg, pred bg)
    '#FFEB3B',  # class confusion
])
325
+
326
+
327
def _get_class_cmap(num_classes):
    """
    Build the colormap and the matching integer-bin norm for label maps.

    The 3-colour palette is used when num_classes is exactly 3; every other
    value uses the 4-colour palette.

    Returns
    -------
    (cmap, norm) : (ListedColormap, BoundaryNorm)
    """
    palette = CLASS_COLORS_4
    if num_classes == 3:
        palette = CLASS_COLORS_3

    # One bin per class label: boundaries at 0, 1, ..., num_classes.
    bin_edges = range(num_classes + 1)
    return ListedColormap(palette), BoundaryNorm(bin_edges, num_classes)
332
+
333
+
334
+ def _build_error_rgb(gt_hw, pred_hw, num_classes):
335
+ """
336
+ Build a pixel-wise error classification map:
337
+ 0 = correct
338
+ 1 = false positive (model predicts fg, GT is bg)
339
+ 2 = false negative (GT is fg, model predicts bg)
340
+ 3 = class confusion (both fg but wrong class)
341
+ """
342
+ gt_fg = gt_hw > 0
343
+ pred_fg = pred_hw > 0
344
+
345
+ err = np.zeros_like(gt_hw, dtype=np.uint8)
346
+ err[~gt_fg & pred_fg] = 1 # FP
347
+ err[gt_fg & ~pred_fg] = 2 # FN
348
+ err[gt_fg & pred_fg & (gt_hw != pred_hw)] = 3 # confusion
349
+ return err
350
+
351
+
352
def _add_class_legend(ax, class_names, num_classes):
    """Attach a small legend to `ax` with one colour patch per class."""
    palette = CLASS_COLORS_4
    if num_classes == 3:
        palette = CLASS_COLORS_3

    handles = []
    for idx in range(num_classes):
        handles.append(mpatches.Patch(color=palette[idx],
                                      label=class_names[idx]))

    ax.legend(handles=handles, loc='lower right', fontsize=7,
              framealpha=0.8, markerscale=0.8)
358
+
359
+
360
+ # ─────────────────────────────────────────────────────────────────────────────
361
+ # SECTION 5 — Diagnostic slice visualization
362
+ # ─────────────────────────────────────────────────────────────────────────────
363
+
364
def _hs_slice_label(slice_num):
    """Format a slice number as a zero-padded label, falling back to str()."""
    try:
        return f"{int(slice_num):03d}"
    except (TypeError, ValueError):
        # Missing ('?') or non-numeric slice numbers are shown verbatim.
        return str(slice_num)


def _hs_styled_ax(fig, pos):
    """Add a subplot at `pos` with the dark theme used by the slice panel."""
    ax = fig.add_subplot(pos)
    ax.set_facecolor('#0D0D0D')
    ax.tick_params(colors='white')
    for spine in ax.spines.values():
        spine.set_edgecolor('#444')
    return ax


def _hs_draw_image_panels(fig, gs, flair, gt_hw, pred_hw, prob_hw,
                          class_names, num_classes):
    """Draw the two image rows (rows 0 and 1) of the hard-slice panel."""
    cmap_cls, norm_cls = _get_class_cmap(num_classes)
    err_map = _build_error_rgb(gt_hw, pred_hw, num_classes)
    colors_cls = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4

    # ── Row 0 ──────────────────────────────────────────────────────────────
    ax00 = _hs_styled_ax(fig, gs[0, 0])
    ax00.imshow(flair, cmap='gray', vmin=flair.min(), vmax=flair.max())
    ax00.set_title('FLAIR', color='white', fontsize=10)
    ax00.axis('off')

    ax01 = _hs_styled_ax(fig, gs[0, 1])
    ax01.imshow(gt_hw, cmap=cmap_cls, norm=norm_cls, interpolation='nearest')
    ax01.set_title('Ground Truth', color='white', fontsize=10)
    ax01.axis('off')
    _add_class_legend(ax01, class_names, num_classes)

    ax02 = _hs_styled_ax(fig, gs[0, 2])
    ax02.imshow(pred_hw, cmap=cmap_cls, norm=norm_cls, interpolation='nearest')
    ax02.set_title('Prediction', color='white', fontsize=10)
    ax02.axis('off')
    _add_class_legend(ax02, class_names, num_classes)

    # GT (solid) and prediction (dashed) contours on FLAIR
    ax03 = _hs_styled_ax(fig, gs[0, 3])
    ax03.imshow(flair, cmap='gray', vmin=flair.min(), vmax=flair.max())
    for cls in range(1, num_classes):
        gt_bin = (gt_hw == cls).astype(np.uint8)
        pred_bin = (pred_hw == cls).astype(np.uint8)
        if gt_bin.any():
            ax03.contour(gt_bin, levels=[0.5], colors=[colors_cls[cls]],
                         linewidths=1.5, linestyles='solid')
        if pred_bin.any():
            ax03.contour(pred_bin, levels=[0.5], colors=[colors_cls[cls]],
                         linewidths=1.2, linestyles='dashed')
    gt_patch = mpatches.Patch(color='white', linestyle='solid',
                              label='GT (solid)')
    pred_patch = mpatches.Patch(color='white', linestyle='dashed',
                                label='Pred (dashed)')
    ax03.legend(handles=[gt_patch, pred_patch], loc='lower right',
                fontsize=7, framealpha=0.7)
    ax03.set_title('GT vs Pred Contours', color='white', fontsize=10)
    ax03.axis('off')

    # ── Row 1 ──────────────────────────────────────────────────────────────
    ax10 = _hs_styled_ax(fig, gs[1, 0])
    im_conf = ax10.imshow(prob_hw, cmap='plasma', vmin=0, vmax=1)
    cbar = plt.colorbar(im_conf, ax=ax10, fraction=0.046, pad=0.04)
    cbar.ax.yaxis.set_tick_params(color='white')
    ax10.set_title('Confidence Map', color='white', fontsize=10)
    ax10.axis('off')

    # Low-confidence overlay on FLAIR
    ax11 = _hs_styled_ax(fig, gs[1, 1])
    ax11.imshow(flair, cmap='gray')
    low_conf_mask = prob_hw < 0.5
    overlay = np.zeros((*flair.shape, 4))
    overlay[low_conf_mask] = [1, 0.3, 0, 0.55]  # orange-red for uncertain regions
    ax11.imshow(overlay)
    ax11.set_title('Low-Confidence Regions (<0.5)', color='white', fontsize=10)
    ax11.axis('off')

    ax12 = _hs_styled_ax(fig, gs[1, 2])
    err_colors = ['#1A1A1A', '#FF5722', '#03A9F4', '#FFEB3B']
    err_cmap = ListedColormap(err_colors)
    err_norm = BoundaryNorm([0, 1, 2, 3, 4], 4)
    ax12.imshow(err_map, cmap=err_cmap, norm=err_norm, interpolation='nearest')
    patches_err = [
        mpatches.Patch(color='#1A1A1A', label='Correct'),
        mpatches.Patch(color='#FF5722', label='False Positive'),
        mpatches.Patch(color='#03A9F4', label='False Negative'),
        mpatches.Patch(color='#FFEB3B', label='Class Confusion'),
    ]
    ax12.legend(handles=patches_err, loc='lower right', fontsize=6.5,
                framealpha=0.8)
    ax12.set_title('Error Type Map', color='white', fontsize=10)
    ax12.axis('off')

    # FLAIR + error overlay
    ax13 = _hs_styled_ax(fig, gs[1, 3])
    flair_rgb = np.stack([flair] * 3, axis=-1)
    # Normalise 0-1
    flair_rgb = (flair_rgb - flair_rgb.min()) / (
        flair_rgb.max() - flair_rgb.min() + 1e-7)
    err_overlay = flair_rgb.copy()
    err_overlay[err_map == 1] = [1.0, 0.34, 0.13]   # FP
    err_overlay[err_map == 2] = [0.01, 0.66, 0.96]  # FN
    err_overlay[err_map == 3] = [1.0, 0.92, 0.23]   # confusion
    ax13.imshow(err_overlay)
    ax13.set_title('FLAIR + Error Overlay', color='white', fontsize=10)
    ax13.axis('off')


def _hs_draw_metric_panels(fig, gs, slice_metrics_row, class_names,
                           num_classes):
    """Draw the metrics row (row 2): Dice bar chart and per-class table."""
    colors_cls = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4

    # ── Row 2: metrics ─────────────────────────────────────────────────────
    ax20 = _hs_styled_ax(fig, gs[2, 0:2])
    ax20.set_facecolor('#111')

    bar_labels = []
    bar_dice = []
    bar_colors = []
    for cls in range(1, num_classes):
        cname = class_names[cls]
        prefix = cname.lower().replace(' ', '_')
        d = slice_metrics_row.get(f'{prefix}_dice', np.nan)
        bar_labels.append(cname)
        # An undefined Dice (class absent) is drawn as a zero-height bar.
        bar_dice.append(d if not np.isnan(d) else 0)
        bar_colors.append(colors_cls[cls])

    x = np.arange(len(bar_labels))
    bars = ax20.bar(x, bar_dice, color=bar_colors, edgecolor='white',
                    linewidth=0.8, width=0.5)
    ax20.axhline(0.5, color='red', linestyle='--', linewidth=1,
                 label='Threshold 0.5')
    ax20.axhline(0.8, color='yellow', linestyle='--', linewidth=1,
                 label='Good 0.8')
    ax20.set_xticks(x)
    ax20.set_xticklabels(bar_labels, color='white', fontsize=9)
    ax20.set_ylim(0, 1.05)
    ax20.set_ylabel('Dice Score', color='white', fontsize=9)
    ax20.set_title('Per-Class Dice', color='white', fontsize=10)
    ax20.tick_params(axis='y', colors='white')
    ax20.legend(fontsize=7, labelcolor='white', framealpha=0.3)
    for bar, val in zip(bars, bar_dice):
        ax20.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
                  f'{val:.3f}', ha='center', color='white', fontsize=9)

    # Table: per-class FP/FN/precision/recall
    ax21 = _hs_styled_ax(fig, gs[2, 2:4])
    ax21.axis('off')

    def _cell(prefix, key):
        # Three-decimal text for a metric, or 'N/A' when it is NaN/missing.
        v = slice_metrics_row.get(f'{prefix}_{key}', np.nan)
        return f'{v:.3f}' if not np.isnan(v) else 'N/A'

    col_labels = ['Class', 'Dice', 'Prec', 'Recall', 'FP rate', 'FN rate',
                  'GT px', 'Pred px']
    table_data = []
    for cls in range(1, num_classes):
        cname = class_names[cls]
        prefix = cname.lower().replace(' ', '_')
        table_data.append([
            cname,
            _cell(prefix, 'dice'), _cell(prefix, 'precision'),
            _cell(prefix, 'recall'),
            _cell(prefix, 'fp_rate'), _cell(prefix, 'fn_rate'),
            str(int(slice_metrics_row.get(f'{prefix}_gt_px', 0))),
            str(int(slice_metrics_row.get(f'{prefix}_pred_px', 0))),
        ])

    tbl = ax21.table(cellText=table_data, colLabels=col_labels,
                     cellLoc='center', loc='center')
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(8)
    tbl.scale(1, 1.6)
    for (r, c), cell in tbl.get_celld().items():
        cell.set_edgecolor('#444')
        if r == 0:
            # Header row
            cell.set_facecolor('#2C2C2C')
            cell.set_text_props(color='white', fontweight='bold')
        else:
            cell.set_facecolor('#1A1A1A')
            cell.set_text_props(color='white')
    ax21.set_title('Per-Class Metrics Summary', color='white', fontsize=10,
                   pad=8)


def visualize_hard_slice(flair, gt_hw, pred_hw, prob_hw,
                         slice_metrics_row, class_names, num_classes,
                         save_path, rank=None):
    """
    Create a 3-row diagnostic panel for a single hard slice and save it.

    Row 0 : FLAIR | GT mask | Predicted mask | GT vs Pred contours on FLAIR
    Row 1 : Confidence map | Low-confidence regions | Error type map |
            FLAIR + error overlay
    Row 2 : Per-class Dice bar chart | Per-class metrics table

    Parameters
    ----------
    flair : np.ndarray — 2-D image shown in grayscale
    gt_hw : np.ndarray — integer label map (ground truth)
    pred_hw : np.ndarray — integer label map (prediction)
    prob_hw : np.ndarray — per-pixel confidence, displayed on a 0–1 scale
    slice_metrics_row : mapping with a `.get` method (e.g. a dict or a
        DataFrame row) holding the slice's metric columns
    class_names : list[str]
    num_classes : int
    save_path : destination passed to `Figure.savefig`
    rank : optional fallback used when the row has no 'difficulty_rank'

    The figure is always closed, even if drawing or saving raises.
    """
    patient_id = slice_metrics_row.get('patient_id', '?')
    slice_num = slice_metrics_row.get('slice_num', '?')
    diff_score = slice_metrics_row.get('difficulty_score', float('nan'))
    diff_rank = slice_metrics_row.get('difficulty_rank', rank)
    mean_conf = slice_metrics_row.get('mean_confidence', float('nan'))
    mean_fg_d = slice_metrics_row.get('mean_fg_dice', float('nan'))

    fig = plt.figure(figsize=(20, 14))
    try:
        fig.patch.set_facecolor('#0D0D0D')
        # The slice number is formatted defensively: the '?' default (or a
        # float coming from a DataFrame row) cannot take the 'd' format code.
        title_str = (f"Patient: {patient_id} | "
                     f"Slice: {_hs_slice_label(slice_num)} | "
                     f"Rank #{diff_rank} | Difficulty: {diff_score:.3f} | "
                     f"Mean FG Dice: {mean_fg_d:.3f} | "
                     f"Mean Conf: {mean_conf:.3f}")
        fig.suptitle(title_str, color='white', fontsize=12,
                     fontweight='bold', y=0.98)

        gs = gridspec.GridSpec(3, 4, figure=fig,
                               hspace=0.35, wspace=0.25,
                               left=0.04, right=0.98,
                               top=0.93, bottom=0.04)

        _hs_draw_image_panels(fig, gs, flair, gt_hw, pred_hw, prob_hw,
                              class_names, num_classes)
        _hs_draw_metric_panels(fig, gs, slice_metrics_row, class_names,
                               num_classes)

        fig.savefig(save_path, dpi=130, bbox_inches='tight',
                    facecolor=fig.get_facecolor())
    finally:
        # Release the figure even on failure so batch runs do not leak memory.
        plt.close(fig)
557
+
558
+
559
+ # ─────────────────────────────────────────────────────────────────────────────
560
+ # SECTION 6 — Patient-level summary visualization
561
+ # ─────────────────────────────────────────────────────────────────────────────
562
+
563
def visualize_patient_summary(patient_id, patient_data, slice_df_patient,
                              class_names, num_classes, save_path):
    """
    Save a one-page 2x2 summary figure for a single patient.

    Panels:
        top-left     — Dice per slice, one line per foreground class
        top-right    — mean confidence vs. pixel error rate, coloured by
                       mean foreground Dice, with the 3 worst slices labelled
        bottom-left  — FP / FN rate per slice for the last class index
        bottom-right — box plot of the Dice distribution per foreground class

    `patient_data` is only used to count the patient's slices; all plotted
    values come from `slice_df_patient`.
    """
    slice_order = np.argsort(patient_data['slice_indices'])
    n_slices = len(np.array(patient_data['slice_indices'])[slice_order])

    fig, axes = plt.subplots(2, 2, figsize=(18, 10))
    fig.patch.set_facecolor('#0D0D0D')
    fig.suptitle(f'Patient Summary | ID: {patient_id} | {n_slices} slices',
                 color='white', fontsize=13, fontweight='bold')

    palette = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4
    table = slice_df_patient.sort_values('slice_num').reset_index(drop=True)

    def _dark_frame(axis):
        # Shared dark-theme styling for ticks and spines.
        axis.tick_params(colors='white')
        for spine in axis.spines.values():
            spine.set_edgecolor('#444')

    def _dice_column(cls):
        # Name of the per-slice Dice column for a class index.
        return f"{class_names[cls].lower().replace(' ', '_')}_dice"

    # ── Plot 1: Per-slice Dice per class ──────────────────────────────────
    panel = axes[0, 0]
    panel.set_facecolor('#111')
    for cls in range(1, num_classes):
        cname = class_names[cls]
        dice_col = _dice_column(cls)
        if dice_col not in table.columns:
            continue
        defined = table[dice_col].notna()
        panel.plot(table.loc[defined, 'slice_num'],
                   table.loc[defined, dice_col],
                   color=palette[cls], linewidth=1.5,
                   marker='o', markersize=3, label=cname)
    panel.axhline(0.5, color='red', linestyle='--', linewidth=0.8, alpha=0.7)
    panel.axhline(0.8, color='yellow', linestyle='--', linewidth=0.8,
                  alpha=0.7)
    panel.set_xlabel('Slice Number', color='white')
    panel.set_ylabel('Dice Score', color='white')
    panel.set_title('Per-Slice Dice by Class', color='white', fontsize=10)
    panel.legend(fontsize=8, labelcolor='white', framealpha=0.3)
    _dark_frame(panel)
    panel.set_ylim(0, 1.05)

    # ── Plot 2: Confidence vs Error rate scatter ───────────────────────────
    panel = axes[0, 1]
    panel.set_facecolor('#111')
    points = panel.scatter(table['mean_confidence'], table['error_rate'],
                           c=table['mean_fg_dice'].fillna(0.5),
                           cmap='RdYlGn', vmin=0, vmax=1,
                           s=50, edgecolors='white', linewidths=0.3,
                           alpha=0.85)
    cbar = plt.colorbar(points, ax=panel)
    cbar.set_label('Mean FG Dice', color='white')
    cbar.ax.yaxis.set_tick_params(color='white')
    plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')
    panel.set_xlabel('Mean Confidence', color='white')
    panel.set_ylabel('Pixel Error Rate', color='white')
    panel.set_title('Confidence vs Error Rate\n(colour = Mean FG Dice)',
                    color='white', fontsize=10)
    _dark_frame(panel)

    # Annotate worst 3 slices (by difficulty score when it is available)
    if 'difficulty_score' in table.columns:
        worst = table.nlargest(3, 'difficulty_score')
    else:
        worst = table.nlargest(3, 'error_rate')
    for _, worst_row in worst.iterrows():
        panel.annotate(f"sl{int(worst_row['slice_num']):03d}",
                       (worst_row['mean_confidence'], worst_row['error_rate']),
                       textcoords="offset points", xytext=(5, 5),
                       fontsize=7, color='white')

    # ── Plot 3: FP / FN pixel rates per slice ─────────────────────────────
    panel = axes[1, 0]
    panel.set_facecolor('#111')
    slice_pos = table['slice_num'].values
    # Use WMH class (last foreground class) as primary interest
    main_cls = num_classes - 1
    main_prefix = class_names[main_cls].lower().replace(' ', '_')
    fp_col = f'{main_prefix}_fp_rate'
    fn_col = f'{main_prefix}_fn_rate'

    if fp_col in table.columns and fn_col in table.columns:
        bar_w = 0.4
        panel.bar(slice_pos - bar_w/2, table[fp_col].fillna(0), width=bar_w,
                  color='#FF5722', alpha=0.8, label='FP Rate')
        panel.bar(slice_pos + bar_w/2, table[fn_col].fillna(0), width=bar_w,
                  color='#03A9F4', alpha=0.8, label='FN Rate')
    panel.set_xlabel('Slice Number', color='white')
    panel.set_ylabel('Rate', color='white')
    panel.set_title(f'FP / FN Rate per Slice [{class_names[main_cls]}]',
                    color='white', fontsize=10)
    panel.legend(fontsize=8, labelcolor='white', framealpha=0.3)
    _dark_frame(panel)

    # ── Plot 4: Dice distribution box plots ───────────────────────────────
    panel = axes[1, 1]
    panel.set_facecolor('#111')
    samples, sample_names, sample_colors = [], [], []
    for cls in range(1, num_classes):
        dice_col = _dice_column(cls)
        if dice_col in table.columns:
            samples.append(table[dice_col].dropna().values)
        else:
            samples.append(np.array([]))
        sample_names.append(class_names[cls])
        sample_colors.append(palette[cls])

    boxes = panel.boxplot(samples, patch_artist=True,
                          medianprops=dict(color='white', linewidth=2))
    for box, box_color in zip(boxes['boxes'], sample_colors):
        box.set_facecolor(box_color)
        box.set_alpha(0.7)
    for part in ['whiskers', 'caps', 'fliers']:
        for artist in boxes[part]:
            artist.set_color('white')

    panel.set_xticklabels(sample_names, color='white')
    panel.set_ylabel('Dice Score', color='white')
    panel.set_title('Dice Score Distribution per Class', color='white',
                    fontsize=10)
    panel.axhline(0.5, color='red', linestyle='--', linewidth=0.8, alpha=0.7)
    panel.axhline(0.8, color='yellow', linestyle='--', linewidth=0.8,
                  alpha=0.7)
    _dark_frame(panel)
    panel.set_ylim(0, 1.05)

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(save_path, dpi=120, bbox_inches='tight',
                facecolor=fig.get_facecolor())
    plt.close(fig)
699
+
700
+
701
+ # ─────────────────────────────────────────────────────────────────────────────
702
+ # SECTION 7 — Dataset-level overview visualizations
703
+ # ─────────────────────────────────────────────────────────────────────────────
704
+
705
+ def visualize_dataset_overview(slice_df, patient_df, class_names,
706
+ num_classes, save_dir):
707
+ """
708
+ Global overview plots saved to save_dir/overview/:
709
+ 1. Dice distribution across all slices (violin per class)
710
+ 2. Patient ranking bar chart (composite dice)
711
+ 3. Error rate histogram
712
+ 4. Confidence vs dice scatter (all slices)
713
+ 5. Difficulty score distribution
714
+ """
715
+ overview_dir = Path(save_dir) / 'overview'
716
+ overview_dir.mkdir(parents=True, exist_ok=True)
717
+
718
+ colors_cls = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4
719
+
720
+ # ── 1. Dice violin ────────────────────────────────────────────────────
721
+ fig, ax = plt.subplots(figsize=(10, 6))
722
+ fig.patch.set_facecolor('#0D0D0D')
723
+ ax.set_facecolor('#111')
724
+
725
+ violin_data = []
726
+ violin_labels = []
727
+ for cls in range(1, num_classes):
728
+ cname = class_names[cls]
729
+ prefix = cname.lower().replace(' ', '_')
730
+ col = f'{prefix}_dice'
731
+ vals = slice_df[col].dropna().values if col in slice_df.columns else np.array([])
732
+ violin_data.append(vals)
733
+ violin_labels.append(cname)
734
+
735
+ parts = ax.violinplot(violin_data, showmedians=True, showextrema=True)
736
+ for i, (pc, color) in enumerate(zip(parts['bodies'],
737
+ [colors_cls[c] for c in range(1, num_classes)])):
738
+ pc.set_facecolor(color)
739
+ pc.set_alpha(0.7)
740
+ parts['cmedians'].set_colors('white')
741
+ parts['cmaxes'].set_colors('#aaa')
742
+ parts['cmins'].set_colors('#aaa')
743
+ parts['cbars'].set_colors('#aaa')
744
+
745
+ ax.set_xticks(range(1, len(violin_labels) + 1))
746
+ ax.set_xticklabels(violin_labels, color='white')
747
+ ax.axhline(0.5, color='red', linestyle='--', linewidth=0.9, label='0.5 threshold')
748
+ ax.axhline(0.8, color='yellow', linestyle='--', linewidth=0.9, label='0.8 target')
749
+ ax.set_ylabel('Dice Score', color='white')
750
+ ax.set_title('Dice Distribution — All Slices', color='white', fontsize=12)
751
+ ax.tick_params(colors='white')
752
+ ax.legend(fontsize=8, labelcolor='white', framealpha=0.3)
753
+ for spine in ax.spines.values():
754
+ spine.set_edgecolor('#444')
755
+ ax.set_ylim(0, 1.05)
756
+
757
+ plt.tight_layout()
758
+ plt.savefig(overview_dir / 'dice_violin_all_slices.png', dpi=130,
759
+ bbox_inches='tight', facecolor=fig.get_facecolor())
760
+ plt.close(fig)
761
+
762
+ # ── 2. Patient ranking bar chart ──────────────────────────────────────
763
+ pat_sorted = patient_df.sort_values('composite_dice').reset_index(drop=True)
764
+ n_patients = len(pat_sorted)
765
+
766
+ fig, ax = plt.subplots(figsize=(max(12, n_patients * 0.6), 5))
767
+ fig.patch.set_facecolor('#0D0D0D')
768
+ ax.set_facecolor('#111')
769
+
770
+ bar_colors = ['#F44336' if v < 0.5 else '#FFC107' if v < 0.7 else '#4CAF50'
771
+ for v in pat_sorted['composite_dice'].fillna(0)]
772
+ ax.bar(range(n_patients), pat_sorted['composite_dice'].fillna(0),
773
+ color=bar_colors, edgecolor='#333', linewidth=0.5)
774
+ ax.set_xticks(range(n_patients))
775
+ ax.set_xticklabels(pat_sorted['patient_id'], rotation=75,
776
+ ha='right', color='white', fontsize=7)
777
+ ax.axhline(0.5, color='red', linestyle='--', linewidth=0.9)
778
+ ax.axhline(0.7, color='orange', linestyle='--', linewidth=0.9)
779
+ ax.axhline(0.8, color='yellow', linestyle='--', linewidth=0.9)
780
+ ax.set_ylabel('Composite Dice (mean FG classes)', color='white')
781
+ ax.set_title('Patient Ranking — Composite Dice (worst → best)',
782
+ color='white', fontsize=12)
783
+ ax.tick_params(colors='white')
784
+ for spine in ax.spines.values():
785
+ spine.set_edgecolor('#444')
786
+ ax.set_ylim(0, 1.05)
787
+
788
+ red_p = mpatches.Patch(color='#F44336', label='< 0.5 (critical)')
789
+ orange_p = mpatches.Patch(color='#FFC107', label='0.5–0.7 (poor)')
790
+ green_p = mpatches.Patch(color='#4CAF50', label='≥ 0.7 (acceptable)')
791
+ ax.legend(handles=[red_p, orange_p, green_p],
792
+ fontsize=8, labelcolor='white', framealpha=0.3)
793
+
794
+ plt.tight_layout()
795
+ plt.savefig(overview_dir / 'patient_ranking.png', dpi=130,
796
+ bbox_inches='tight', facecolor=fig.get_facecolor())
797
+ plt.close(fig)
798
+
799
+ # ── 3. Error rate histogram ────────────────────────────────────────────
800
+ fig, ax = plt.subplots(figsize=(9, 5))
801
+ fig.patch.set_facecolor('#0D0D0D')
802
+ ax.set_facecolor('#111')
803
+ ax.hist(slice_df['error_rate'].dropna(), bins=40, color='#9C27B0',
804
+ edgecolor='white', linewidth=0.3, alpha=0.85)
805
+ ax.set_xlabel('Pixel Error Rate per Slice', color='white')
806
+ ax.set_ylabel('Count', color='white')
807
+ ax.set_title('Pixel Error Rate Distribution — All Slices', color='white', fontsize=12)
808
+ ax.tick_params(colors='white')
809
+ for spine in ax.spines.values():
810
+ spine.set_edgecolor('#444')
811
+ plt.tight_layout()
812
+ plt.savefig(overview_dir / 'error_rate_histogram.png', dpi=130,
813
+ bbox_inches='tight', facecolor=fig.get_facecolor())
814
+ plt.close(fig)
815
+
816
+ # ── 4. Confidence vs mean FG Dice scatter ─────────────────────────────
817
+ fig, ax = plt.subplots(figsize=(9, 6))
818
+ fig.patch.set_facecolor('#0D0D0D')
819
+ ax.set_facecolor('#111')
820
+ sc = ax.scatter(slice_df['mean_confidence'], slice_df['mean_fg_dice'].fillna(0),
821
+ c=slice_df['error_rate'], cmap='RdYlGn_r',
822
+ vmin=0, vmax=0.3, s=10, alpha=0.6)
823
+ cbar = plt.colorbar(sc, ax=ax)
824
+ cbar.set_label('Pixel Error Rate', color='white')
825
+ cbar.ax.yaxis.set_tick_params(color='white')
826
+ plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')
827
+ ax.set_xlabel('Mean Softmax Confidence', color='white')
828
+ ax.set_ylabel('Mean FG Dice', color='white')
829
+ ax.set_title('Confidence vs FG Dice — All Slices', color='white', fontsize=12)
830
+ ax.tick_params(colors='white')
831
+ for spine in ax.spines.values():
832
+ spine.set_edgecolor('#444')
833
+ plt.tight_layout()
834
+ plt.savefig(overview_dir / 'confidence_vs_dice_scatter.png', dpi=130,
835
+ bbox_inches='tight', facecolor=fig.get_facecolor())
836
+ plt.close(fig)
837
+
838
+ # ── 5. Difficulty score distribution ──────────────────────────────────
839
+ if 'difficulty_score' in slice_df.columns:
840
+ fig, ax = plt.subplots(figsize=(9, 5))
841
+ fig.patch.set_facecolor('#0D0D0D')
842
+ ax.set_facecolor('#111')
843
+ ax.hist(slice_df['difficulty_score'].dropna(), bins=40,
844
+ color='#FF9800', edgecolor='white', linewidth=0.3, alpha=0.85)
845
+ ax.set_xlabel('Difficulty Score', color='white')
846
+ ax.set_ylabel('Count', color='white')
847
+ ax.set_title('Difficulty Score Distribution — All Slices', color='white', fontsize=12)
848
+ ax.tick_params(colors='white')
849
+ for spine in ax.spines.values():
850
+ spine.set_edgecolor('#444')
851
+ plt.tight_layout()
852
+ plt.savefig(overview_dir / 'difficulty_score_histogram.png', dpi=130,
853
+ bbox_inches='tight', facecolor=fig.get_facecolor())
854
+ plt.close(fig)
855
+
856
+ print(f" ✅ Overview plots saved to: {overview_dir}")
857
+
858
+
859
+ # ─────────────────────────────────────────────────────────────────────────────
860
+ # SECTION 8 — Main entry point: run_error_analysis
861
+ # ─────────────────────────────────────────────────────────────────────────────
862
+
863
def run_error_analysis(results, config,
                       top_n_slices=30,
                       top_n_patients=10,
                       fg_dice_weight=0.6,
                       error_rate_weight=0.2,
                       confidence_weight=0.2):
    """
    Full pipeline: build tables → rank → save CSVs → generate visualizations.

    Call after run_inference():
        results = run_inference(config)
        run_error_analysis(results, config)

    Parameters
    ----------
    results         : dict — returned by run_inference(); only the
                      'patients_results' entry is read here
    config          : InferenceConfig — supplies class_names, num_classes and
                      inference_dir (outputs go to inference_dir/error_analysis)
    top_n_slices    : int  — how many hardest slices to visualize individually
    top_n_patients  : int  — how many hardest patients to get summary plots
    fg_dice_weight, error_rate_weight, confidence_weight : floats forwarded to
                      rank_slices() for the difficulty ranking

    Returns
    -------
    dict with keys 'slice_df' and 'patient_df' (the ranked tables) and
    'error_dir' (root directory of everything written by this function).
    """
    patient_results = results['patients_results']
    class_names = config.class_names
    num_classes = config.num_classes

    # Output sub-directories (creating them also creates error_dir itself)
    error_dir = config.inference_dir / 'error_analysis'
    hard_slices_dir = error_dir / 'hard_slices'
    patient_summaries_dir = error_dir / 'patient_summaries'
    tables_dir = error_dir / 'tables'

    for d in (hard_slices_dir, patient_summaries_dir, tables_dir):
        d.mkdir(parents=True, exist_ok=True)

    print("\n" + "=" * 70)
    print("ERROR ANALYSIS — Building slice & patient tables")
    print("=" * 70)

    # ── Step 1: build tables ──────────────────────────────────────────────
    slice_df, patient_df = build_error_tables(patient_results, num_classes, class_names)

    # ── Step 2: rank ──────────────────────────────────────────────────────
    # NOTE(review): head() below is what selects the "hardest" items, so this
    # assumes both ranking helpers return rows ordered hardest-first — confirm.
    slice_df = rank_slices(slice_df, class_names, num_classes,
                           fg_dice_weight, error_rate_weight, confidence_weight)
    patient_df = rank_patients(patient_df)

    # ── Step 3: save CSVs ─────────────────────────────────────────────────
    slice_csv = tables_dir / 'slice_difficulty_ranking.csv'
    patient_csv = tables_dir / 'patient_difficulty_ranking.csv'
    slice_df.to_csv(slice_csv, index=False)
    patient_df.to_csv(patient_csv, index=False)
    print(f" ✅ Slice table → {slice_csv}")
    print(f" ✅ Patient table → {patient_csv}")

    # ── Step 4: dataset overview plots ────────────────────────────────────
    print("\nGenerating dataset overview plots...")
    visualize_dataset_overview(slice_df, patient_df, class_names,
                               num_classes, error_dir)

    # ── Step 5: hard slice visualizations ────────────────────────────────
    print(f"\nVisualizing top-{top_n_slices} hardest slices...")
    hard_slices = slice_df.head(top_n_slices)

    for _, row in tqdm(hard_slices.iterrows(),
                       total=len(hard_slices), desc="Hard slice panels"):
        patient_id = row['patient_id']
        slice_num = int(row['slice_num'])

        # Skip rows whose patient is missing from the raw results instead of
        # raising KeyError (same guard as the patient-summary loop below).
        if patient_id not in patient_results:
            continue
        data = patient_results[patient_id]

        # Look the slice up directly in the stored order. Only one slice is
        # needed, so there is no reason to convert and re-sort every array of
        # the patient for each panel.
        matches = np.flatnonzero(np.asarray(data['slice_indices']) == slice_num)
        if matches.size == 0:
            continue
        idx = int(matches[0])

        gt_hw = np.asarray(data['ground_truths'][idx])
        pred_hw = np.asarray(data['predictions'][idx])
        prob_hw = np.asarray(data['probabilities'][idx])
        flair_hw = np.asarray(data['flairs'][idx])

        # Collapse one-hot GT if needed
        if gt_hw.ndim == 3:
            gt_hw = np.argmax(gt_hw, axis=-1)

        rank = int(row['difficulty_rank'])
        fname = (f"rank{rank:04d}_"
                 f"{patient_id}_slice{slice_num:03d}"
                 f"_dice{row['mean_fg_dice']:.3f}.png")
        save_path = hard_slices_dir / fname

        visualize_hard_slice(
            flair=flair_hw,
            gt_hw=gt_hw,
            pred_hw=pred_hw,
            prob_hw=prob_hw,
            slice_metrics_row=row.to_dict(),
            class_names=class_names,
            num_classes=num_classes,
            save_path=save_path,
            rank=rank
        )

    print(f" ✅ Hard slice panels → {hard_slices_dir}")

    # ── Step 6: patient summary visualizations ────────────────────────────
    print(f"\nGenerating top-{top_n_patients} hardest patient summaries...")
    hard_patients = patient_df.head(top_n_patients)

    for _, pat_row in tqdm(hard_patients.iterrows(),
                           total=len(hard_patients), desc="Patient summaries"):
        patient_id = pat_row['patient_id']
        if patient_id not in patient_results:
            continue

        data = patient_results[patient_id]
        slice_df_patient = slice_df[slice_df['patient_id'] == patient_id].copy()

        rank = int(pat_row['difficulty_rank'])
        comp = pat_row.get('composite_dice', float('nan'))
        fname = (f"rank{rank:03d}_{patient_id}"
                 f"_composite{comp:.3f}.png")
        save_path = patient_summaries_dir / fname

        visualize_patient_summary(
            patient_id=patient_id,
            patient_data=data,
            slice_df_patient=slice_df_patient,
            class_names=class_names,
            num_classes=num_classes,
            save_path=save_path
        )

    print(f" ✅ Patient summaries → {patient_summaries_dir}")

    # ── Step 7: print console summary ─────────────────────────────────────
    print("\n" + "=" * 70)
    print("ERROR ANALYSIS SUMMARY")
    print("=" * 70)
    print(f"\nTotal slices analysed : {len(slice_df)}")
    print(f"Total patients : {len(patient_df)}")

    print("\nTop-10 Hardest Slices:")
    top10_cols = ['difficulty_rank', 'slice_id', 'mean_fg_dice',
                  'error_rate', 'mean_confidence', 'difficulty_score']
    top10_cols = [c for c in top10_cols if c in slice_df.columns]
    print(slice_df[top10_cols].head(10).to_string(index=False))

    print("\nTop-10 Hardest Patients:")
    fg_dice_cols = [f"{class_names[c].lower().replace(' ', '_')}_mean_dice"
                    for c in range(1, num_classes)]
    pat_cols = ['difficulty_rank', 'patient_id', 'n_slices', 'composite_dice'] + fg_dice_cols
    # Only print columns that exist, as the slice table above already does,
    # so a missing optional column cannot abort the run after all outputs
    # have been written.
    pat_cols = [c for c in pat_cols if c in patient_df.columns]
    print(patient_df[pat_cols].head(10).to_string(index=False))

    print("\n" + "=" * 70)
    print(f"All error analysis outputs → {error_dir}")
    print("=" * 70 + "\n")

    return {
        'slice_df': slice_df,
        'patient_df': patient_df,
        'error_dir': error_dir
    }
models/for_WMH_Vent/model_training_scripts/p4_folds_results_aggregator.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 - All U-Net models with Adaptive Loss (WCE + UFL)
3
+
4
+ WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
5
+ Three-class segmentation: Background vs Ventricles vs Abnormal WMH
6
+ Professional results saving and visualization for publication
7
+
8
+ This relates to our article:
9
+ "Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
10
+ A Large-Scale Multiple Sclerosis Study"
11
+
12
+ Features:
13
+ - Aggregation of all inference results
14
+ - Includes lesion-level (connected-component) metrics: sensitivity, precision,
15
+ F1, TP/FP/FN lesion counts (added to address reviewer R1C7)
16
+
17
+ Authors:
18
+ "Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
19
+
20
+ Developer:
21
+ "Mahdi Bashiri Bawil"
22
+ """
23
+
24
+ import os
25
+ import json
26
+ import pandas as pd
27
+ import numpy as np
28
+ from pathlib import Path
29
+ import warnings
30
+ warnings.filterwarnings('ignore')
31
+
32
+
33
+ class ResultsAggregator:
34
+ """
35
+ Aggregates segmentation results across multiple variants and folds.
36
+ """
37
+
38
+ def __init__(self, base_dir='./'):
39
+ """
40
+ Initialize the aggregator.
41
+
42
+ Args:
43
+ base_dir: Base directory containing all results folders
44
+ """
45
+ self.base_dir = Path(base_dir)
46
+ self.variants = {
47
+ 1: "unet",
48
+ 2: "attnunet",
49
+ 3: "dlv3unet",
50
+ 4: "transunet"
51
+ }
52
+ self.class_names = ["Background", "Ventricles", "Abnormal_WMH"]
53
+ self.num_variants = 4
54
+ self.num_folds = 4
55
+
56
def find_results_folders(self):
    """
    Collect every existing ``results_fold_<fold>_var_<variant>_zscore2``
    directory under ``self.base_dir``.

    Returns:
        List of dicts with keys 'variant' (1-based id), 'fold' (0-based id)
        and 'path', ordered by variant and then by fold.
    """
    found = []
    for variant_id in range(1, self.num_variants + 1):
        candidates = (
            (fold_id, self.base_dir / f"results_fold_{fold_id}_var_{variant_id}_zscore2")
            for fold_id in range(self.num_folds)
        )
        found.extend(
            {'variant': variant_id, 'fold': fold_id, 'path': folder}
            for fold_id, folder in candidates
            if folder.exists()
        )
    return found
70
+
71
def load_test_metrics(self, results_folder):
    """
    Read ``test_metrics_complete.json`` for one results folder.

    Args:
        results_folder: dict whose 'path' entry is the results directory.

    Returns:
        The parsed JSON content, or None (after printing a warning) when the
        metrics file does not exist.
    """
    metrics_path = results_folder['path'].joinpath(
        'inference_all_test', 'standard_3class', 'metrics',
        'test_metrics_complete.json')

    if metrics_path.exists():
        with open(metrics_path, 'r') as handle:
            return json.load(handle)

    print(f"Warning: Metrics file not found at {metrics_path}")
    return None
83
+
84
def load_training_summary(self, results_folder):
    """
    Read ``training_summary.json`` (new format) for one results folder.

    Args:
        results_folder: dict with 'path' (results directory) and 'fold'.

    Returns:
        The parsed summary; when the summary file is absent, whatever
        ``load_training_history`` returns for the same folder.
    """
    fold_dir = (results_folder['path'] / 'models' / 'standard_3class'
                / f"fold_{results_folder['fold']}")
    summary_path = fold_dir / 'training_summary.json'

    if summary_path.exists():
        with open(summary_path, 'r') as handle:
            return json.load(handle)

    # No summary on disk: fall back to the legacy history.json loader.
    return self.load_training_history(results_folder)
96
+
97
def load_training_history(self, results_folder):
    """
    Read the legacy ``history.json`` for one results folder.

    Args:
        results_folder: dict with 'path' (results directory) and 'fold'.

    Returns:
        The parsed JSON content, or None (after printing a warning) when the
        history file does not exist.
    """
    fold_dir = (results_folder['path'] / 'models' / 'standard_3class'
                / f"fold_{results_folder['fold']}")
    history_path = fold_dir / 'history.json'

    if history_path.exists():
        with open(history_path, 'r') as handle:
            return json.load(handle)

    print(f"Warning: History file not found at {history_path}")
    return None
109
+
110
def load_best_epoch_analysis(self, results_folder):
    """
    Read ``best_epoch_analysis.json`` (new format) for one results folder.

    Args:
        results_folder: dict with 'path' (results directory) and 'fold'.

    Returns:
        The parsed JSON content, or None when the file does not exist
        (no warning is printed in that case).
    """
    fold_dir = (results_folder['path'] / 'models' / 'standard_3class'
                / f"fold_{results_folder['fold']}")
    analysis_path = fold_dir / 'best_epoch_analysis.json'

    if analysis_path.exists():
        with open(analysis_path, 'r') as handle:
            return json.load(handle)
    return None
121
+
122
def extract_test_metrics_row(self, results_folder, metrics_data):
    """
    Extract a row of test metrics for the summary dataframe.
    Includes both voxel-level and lesion-level metrics.

    Args:
        results_folder: dict with 'variant' and 'fold' for this run.
        metrics_data: parsed test-metrics JSON; must contain
            ``config.test_samples`` and a ``metrics`` mapping with one entry
            per voxel-level metric. A ``metrics.lesion`` entry is optional.

    Returns:
        Flat dict of column name -> value, or None when metrics_data is None.
        Values missing from the JSON are stored as None.
    """
    if metrics_data is None:
        return None

    row = {
        'Variant': results_folder['variant'],
        'Variant_Name': self.variants[results_folder['variant']],
        'Fold': results_folder['fold'],
        'Test_Samples': metrics_data['config']['test_samples']
    }

    # Background (class 0) is never reported; only foreground classes are.
    foreground_classes = range(1, len(self.class_names))

    # ── Voxel-level metrics ─────────────────────────────────────────────
    for metric_name in ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95']:
        metric_data = metrics_data['metrics'][metric_name]

        for class_idx in foreground_classes:
            row[f'{metric_name.upper()}_class_{class_idx}'] = metric_data.get(f'class_{class_idx}')

        row[f'{metric_name.upper()}_mean'] = metric_data.get('mean')

    # ── Lesion-level metrics (R1C7) ─────────────────────────────────────
    lesion_data = metrics_data['metrics'].get('lesion', None)
    if lesion_data is not None:
        # The previous range(2) covered classes 0 and 1, so the class_2
        # columns that create_per_class_summary() and
        # create_variant_comparison() read were never written. Use the same
        # foreground indices as the voxel-level loop.
        for class_idx in foreground_classes:
            cls = lesion_data.get(f'class_{class_idx}', {})

            # Scalar rates (averaged across patients in inference script)
            for sk in ['lesion_sensitivity', 'lesion_precision', 'lesion_f1']:
                row[f'LESION_{sk.upper()}_class_{class_idx}'] = cls.get(sk)

            # Integer counts (summed across patients in inference script)
            for ck in ['n_gt_lesions', 'n_pred_lesions', 'tp_lesions', 'fn_lesions', 'fp_lesions']:
                row[f'LESION_{ck.upper()}_class_{class_idx}'] = cls.get(ck)

        # Cross-class summary keys produced by aggregate_patient_metrics()
        for sk in ['lesion_sensitivity', 'lesion_precision', 'lesion_f1']:
            row[f'LESION_{sk.upper()}_mean'] = lesion_data.get(f'mean_{sk}')
        for ck in ['n_gt_lesions', 'n_pred_lesions', 'tp_lesions', 'fn_lesions', 'fp_lesions']:
            row[f'LESION_{ck.upper()}_total'] = lesion_data.get(f'total_{ck}')

    return row
171
+
172
def extract_training_info_row(self, results_folder, training_data, best_epoch_analysis):
    """
    Build one training-summary row for a (variant, fold) run.

    The source of the values is chosen in this order:
      1. ``training_data`` in training_summary.json layout
         (has a 'best_epoch_selection' key);
      2. ``best_epoch_analysis`` (best_epoch_analysis.json), if given;
      3. ``training_data`` in legacy history.json layout
         (has a 'val_metrics' key).
    If none applies, only the identifying columns are returned.

    Returns:
        Dict of column name -> value, or None when training_data is None.
    """
    if training_data is None:
        return None

    variant_id = results_folder['variant']
    row = {
        'Variant': variant_id,
        'Variant_Name': self.variants[variant_id],
        'Fold': results_folder['fold']
    }

    data_is_dict = isinstance(training_data, dict)

    if data_is_dict and 'best_epoch_selection' in training_data:
        # ── training_summary.json ───────────────────────────────────────
        selection = training_data['best_epoch_selection']
        row['Best_Epoch'] = selection['overall_best_epoch']
        row['Composite_Score'] = selection['composite_score']
        row['Total_Epochs'] = training_data['training_config']['total_epochs']

        # 'valid_epochs' is optional (original note: only written by
        # variants trained with beta scheduling); without it every epoch
        # counts as valid.
        if 'valid_epochs' in selection:
            valid = selection['valid_epochs']
            row['First_Valid_Epoch'] = valid['first_valid_epoch']
            row['Total_Valid_Epochs'] = valid['total_valid_epochs']
        else:
            row['First_Valid_Epoch'] = 1
            row['Total_Valid_Epochs'] = training_data['training_config']['total_epochs']

        at_best = training_data['best_epoch_metrics']
        row['Best_Epoch_Val_Loss'] = at_best['val_loss']
        row['Best_Epoch_Dice_Ventricles'] = at_best['dice']['class_1']
        row['Best_Epoch_Dice_Abnormal_WMH'] = at_best['dice'].get('class_2', None)
        row['Best_Epoch_Dice_Mean'] = at_best['dice']['mean']

        priority = training_data['priority_metrics']
        row['Best_Abnormal_Epoch'] = priority['abnormal_wmh']['best_epoch']
        row['Best_Abnormal_Dice'] = priority['abnormal_wmh']['best_dice']
        row['Best_Ventricles_Epoch'] = priority['ventricles']['best_epoch']
        row['Best_Ventricles_Dice'] = priority['ventricles']['best_dice']

    elif best_epoch_analysis is not None:
        # ── best_epoch_analysis.json ────────────────────────────────────
        for column, key in (
                ('Best_Epoch', 'best_overall_epoch'),
                ('Composite_Score', 'composite_score'),
                ('Total_Epochs', 'total_epochs'),
                ('First_Valid_Epoch', 'first_valid_epoch'),
                ('Total_Valid_Epochs', 'total_valid_epochs')):
            row[column] = best_epoch_analysis[key]

        at_best = best_epoch_analysis['best_epoch_metrics']
        row['Best_Epoch_Val_Loss'] = at_best['val_loss']
        row['Best_Epoch_Dice_Ventricles'] = at_best['dice']['class_1']
        row['Best_Epoch_Dice_Abnormal_WMH'] = at_best['dice'].get('class_2', None)
        row['Best_Epoch_Dice_Mean'] = at_best['dice']['mean']

        for column, key in (
                ('Best_Abnormal_Epoch', 'best_abnormal_epoch'),
                ('Best_Abnormal_Dice', 'best_abnormal_dice'),
                ('Best_Ventricles_Epoch', 'best_ventricles_epoch'),
                ('Best_Ventricles_Dice', 'best_ventricles_dice')):
            row[column] = best_epoch_analysis[key]

    elif data_is_dict and 'val_metrics' in training_data:
        # ── legacy history.json ─────────────────────────────────────────
        per_epoch = training_data['val_metrics']
        if 'best_epoch_analysis' in training_data:
            embedded = training_data['best_epoch_analysis']
            row['Best_Epoch'] = embedded['best_overall_epoch']
            row['Composite_Score'] = embedded.get('composite_score', None)
        else:
            # No stored analysis: take the epoch with the highest mean
            # validation Dice (epochs are reported 1-based).
            mean_dice = [epoch['dice']['mean'] for epoch in per_epoch]
            peak = max(mean_dice)
            row['Best_Epoch'] = mean_dice.index(peak) + 1
            row['Composite_Score'] = peak

        row['Total_Epochs'] = len(per_epoch)

    return row
245
+
246
def create_test_metrics_summary(self):
    """
    Create a comprehensive summary of test metrics.

    Returns:
        DataFrame with one row per (variant, fold) whose metrics could be
        loaded, sorted by 'Variant' then 'Fold'; None when no results folder
        exists or none of them yields a metrics row.
    """
    results_folders = self.find_results_folders()

    if not results_folders:
        print("No results folders found!")
        return None

    rows = []
    for folder in results_folders:
        metrics_data = self.load_test_metrics(folder)
        row = self.extract_test_metrics_row(folder, metrics_data)
        if row is not None:
            rows.append(row)

    # Folders can exist before inference has written any metrics file. An
    # empty frame has no 'Variant'/'Fold' columns, so sorting it would raise
    # KeyError; report "nothing to summarise" the same way as above instead.
    if not rows:
        print("No test metrics could be loaded from the results folders!")
        return None

    df = pd.DataFrame(rows)
    return df.sort_values(['Variant', 'Fold']).reset_index(drop=True)
265
+
266
def create_training_summary(self):
    """
    Create a comprehensive summary of training information.

    Returns:
        DataFrame with one row per (variant, fold) whose training data could
        be loaded, sorted by 'Variant' then 'Fold'; None when no results
        folder exists or none of them yields a training row.
    """
    results_folders = self.find_results_folders()

    if not results_folders:
        print("No results folders found!")
        return None

    rows = []
    for folder in results_folders:
        training_data = self.load_training_summary(folder)
        best_epoch_analysis = self.load_best_epoch_analysis(folder)
        row = self.extract_training_info_row(folder, training_data, best_epoch_analysis)
        if row is not None:
            rows.append(row)

    # An empty frame has no 'Variant'/'Fold' columns, so sorting it would
    # raise KeyError; return None like the "no folders" case so callers'
    # existing `is not None` checks skip the dependent steps.
    if not rows:
        print("No training information could be loaded from the results folders!")
        return None

    df = pd.DataFrame(rows)
    return df.sort_values(['Variant', 'Fold']).reset_index(drop=True)
286
+
287
def create_per_class_summary(self, test_metrics_df):
    """
    Create per-class summary statistics across folds for each variant.
    Includes both voxel-level and lesion-level metrics.

    Args:
        test_metrics_df: per-(variant, fold) table with a 'Variant' column and
            metric columns named as written by extract_test_metrics_row().

    Returns:
        DataFrame with one row per (variant, foreground class). Voxel metrics
        get mean/std/min/max over folds, lesion rates get mean/std, lesion
        counts are summed. A metric with no usable fold value yields NaN
        statistics (0 for count totals) instead of raising.
    """
    def fold_values(frame, column):
        """Return the non-missing values of *column* as a float array."""
        return frame[column].dropna().to_numpy(dtype=float)

    summaries = []

    # Variant ids start at 1; index 0 never matches and is skipped below.
    for variant in range(self.num_variants + 1):
        variant_data = test_metrics_df[test_metrics_df['Variant'] == variant]

        if len(variant_data) == 0:
            continue

        # Foreground classes only (class 0 is background).
        for class_idx in range(1, len(self.class_names)):
            class_summary = {
                'Variant': variant,
                'Variant_Name': self.variants[variant],
                'Class': class_idx,
                'Class_Name': self.class_names[class_idx]
            }

            # Voxel-level metrics
            for metric in ['DICE', 'PRECISION', 'RECALL', 'IOU', 'SPECIFICITY', 'HD95']:
                col_name = f'{metric}_class_{class_idx}'
                if col_name in variant_data.columns:
                    values = fold_values(variant_data, col_name)
                    # np.min / np.max raise ValueError on an empty array
                    # (e.g. a metric missing in every fold), so guard all
                    # four statistics the same way the lesion metrics do.
                    has_values = values.size > 0
                    class_summary[f'{metric}_mean'] = np.mean(values) if has_values else np.nan
                    class_summary[f'{metric}_std'] = np.std(values) if has_values else np.nan
                    class_summary[f'{metric}_min'] = np.min(values) if has_values else np.nan
                    class_summary[f'{metric}_max'] = np.max(values) if has_values else np.nan

            # Lesion-level scalar metrics (mean ± std across folds).
            # Source columns carry a doubled "LESION_" prefix because
            # extract_test_metrics_row() prefixes the already-prefixed key.
            for sk in ['LESION_SENSITIVITY', 'LESION_PRECISION', 'LESION_F1']:
                col_name = f'LESION_{sk}_class_{class_idx}'
                if col_name in variant_data.columns:
                    values = fold_values(variant_data, col_name)
                    class_summary[f'{sk}_mean'] = np.mean(values) if len(values) else np.nan
                    class_summary[f'{sk}_std'] = np.std(values) if len(values) else np.nan

            # Lesion-level count metrics (sum across folds — total pool)
            for ck in ['N_GT_LESIONS', 'N_PRED_LESIONS', 'TP_LESIONS', 'FN_LESIONS', 'FP_LESIONS']:
                col_name = f'LESION_{ck}_class_{class_idx}'
                if col_name in variant_data.columns:
                    values = fold_values(variant_data, col_name)
                    class_summary[f'LESION_{ck}_total'] = int(np.sum(values)) if len(values) else 0

            summaries.append(class_summary)

    df = pd.DataFrame(summaries)
    return df
340
+
341
def create_variant_comparison(self, test_metrics_df):
    """
    Build the variant-level comparison table (one row per variant).

    For each variant present in ``test_metrics_df`` the fold values are
    reduced to:
      * voxel metrics  — mean/std of the cross-class mean and of classes 1, 2;
      * lesion rates   — mean/std of the cross-class mean and of class 2
                         (NaN when no fold has a value);
      * lesion counts  — summed totals overall and for class 2
                         (0 when no fold has a value).
    Only columns that exist in the input contribute output columns.
    """
    voxel_metrics = ('DICE', 'PRECISION', 'RECALL', 'IOU', 'SPECIFICITY', 'HD95')
    lesion_rates = ('LESION_SENSITIVITY', 'LESION_PRECISION', 'LESION_F1')
    lesion_counts = ('N_GT_LESIONS', 'N_PRED_LESIONS', 'TP_LESIONS',
                     'FN_LESIONS', 'FP_LESIONS')

    table = []
    for variant_id in range(self.num_variants + 1):
        per_variant = test_metrics_df[test_metrics_df['Variant'] == variant_id]
        if per_variant.shape[0] == 0:
            continue

        available = per_variant.columns
        entry = {
            'Variant': variant_id,
            'Variant_Name': self.variants[variant_id],
            'N_Folds': len(per_variant)
        }

        # ── Voxel-level metrics ─────────────────────────────────────────
        for metric in voxel_metrics:
            # (source column, output label): overall mean first, then
            # Ventricles (1) and Abnormal_WMH (2).
            targets = [(f'{metric}_mean', metric)]
            targets += [(f'{metric}_class_{c}', f'{metric}_Class{c}') for c in (1, 2)]
            for source, label in targets:
                if source not in available:
                    continue
                fold_vals = per_variant[source].dropna().values
                entry[f'{label}_Mean'] = np.mean(fold_vals)
                entry[f'{label}_Std'] = np.std(fold_vals)

        # ── Lesion-level scalar metrics (mean ± std across folds) ───────
        for rate in lesion_rates:
            targets = [(f'LESION_{rate}_mean', rate),
                       (f'LESION_{rate}_class_2', f'{rate}_Class2')]
            for source, label in targets:
                if source not in available:
                    continue
                fold_vals = per_variant[source].dropna().values
                if len(fold_vals):
                    entry[f'{label}_Mean'] = np.mean(fold_vals)
                    entry[f'{label}_Std'] = np.std(fold_vals)
                else:
                    entry[f'{label}_Mean'] = np.nan
                    entry[f'{label}_Std'] = np.nan

        # ── Lesion-level count metrics (sum across folds) ───────────────
        for count in lesion_counts:
            targets = [(f'LESION_{count}_total', f'LESION_{count}_Total'),
                       (f'LESION_{count}_class_2', f'LESION_{count}_Class2_Total')]
            for source, out_key in targets:
                if source not in available:
                    continue
                fold_vals = per_variant[source].dropna().values
                entry[out_key] = int(np.sum(fold_vals)) if len(fold_vals) else 0

        table.append(entry)

    return pd.DataFrame(table)
413
+
414
def create_training_comparison(self, training_df):
    """
    Summarise training convergence per variant across folds.

    Args:
        training_df: per-(variant, fold) training table, or None.

    Returns:
        DataFrame with one row per variant found in ``training_df`` holding
        best-epoch mean/std/min/max, composite-score mean/std and mean/std of
        the validation metrics at the best epoch; None when training_df is
        None.
    """
    if training_df is None:
        return None

    best_epoch_metric_cols = ('Best_Epoch_Val_Loss', 'Best_Epoch_Dice_Mean',
                              'Best_Epoch_Dice_Ventricles', 'Best_Epoch_Dice_Abnormal_WMH')

    table = []
    for variant_id in range(self.num_variants + 1):
        per_variant = training_df[training_df['Variant'] == variant_id]
        if per_variant.shape[0] == 0:
            continue

        available = per_variant.columns
        entry = {
            'Variant': variant_id,
            'Variant_Name': self.variants[variant_id],
            'N_Folds': len(per_variant)
        }

        # Best epoch statistics (missing values are not dropped here)
        if 'Best_Epoch' in available:
            epochs = per_variant['Best_Epoch'].values
            for suffix, reducer in (('Mean', np.mean), ('Std', np.std),
                                    ('Min', np.min), ('Max', np.max)):
                entry[f'Best_Epoch_{suffix}'] = reducer(epochs)

        # Composite score statistics
        if 'Composite_Score' in available:
            scores = per_variant['Composite_Score'].dropna().values
            entry['Composite_Score_Mean'] = np.mean(scores)
            entry['Composite_Score_Std'] = np.std(scores)

        # Validation metrics at best epoch — written only if some fold has one
        for column in best_epoch_metric_cols:
            if column not in available:
                continue
            fold_vals = per_variant[column].dropna().values
            if len(fold_vals) > 0:
                entry[f'{column}_Mean'] = np.mean(fold_vals)
                entry[f'{column}_Std'] = np.std(fold_vals)

        table.append(entry)

    return pd.DataFrame(table)
458
+
459
def generate_all_summaries(self, output_dir='./folds_results'):
    """
    Generate all summary CSV files.

    Builds the five summary tables and writes each available one to
    ``output_dir``. Tables that depend on a missing input are skipped.

    Args:
        output_dir: directory for the CSV files; created (including missing
            parent directories) if it does not exist.

    Returns:
        Dict with keys 'test_metrics', 'training_info', 'per_class',
        'variant_comparison_test' and 'variant_comparison_training'; each
        value is the corresponding DataFrame or None if it was not produced.
    """
    output_path = Path(output_dir)
    # parents=True so a nested output_dir works on a fresh checkout.
    output_path.mkdir(parents=True, exist_ok=True)

    def save_table(df, filename):
        """Write *df* as CSV into the output directory and report it."""
        output_file = output_path / filename
        df.to_csv(output_file, index=False)
        print(f" ✓ Saved: {output_file}")
        print(f" - Shape: {df.shape}")

    print("=" * 80)
    print("RESULTS AGGREGATION STARTED")
    print("=" * 80)

    # 1. Test Metrics Summary (all variants, all folds)
    print("\n1. Generating test metrics summary...")
    test_metrics_df = self.create_test_metrics_summary()
    if test_metrics_df is not None:
        save_table(test_metrics_df, 'test_metrics_all_variants_folds.csv')

    # 2. Training Summary
    print("\n2. Generating training summary...")
    training_df = self.create_training_summary()
    if training_df is not None:
        save_table(training_df, 'training_info_all_variants_folds.csv')

    # 3. Per-Class Summary
    print("\n3. Generating per-class summary...")
    per_class_df = None
    if test_metrics_df is not None:
        per_class_df = self.create_per_class_summary(test_metrics_df)
        save_table(per_class_df, 'per_class_summary.csv')

    # 4. Variant Comparison (Test Metrics)
    print("\n4. Generating variant comparison (test metrics)...")
    variant_comparison_df = None
    if test_metrics_df is not None:
        variant_comparison_df = self.create_variant_comparison(test_metrics_df)
        save_table(variant_comparison_df, 'variant_comparison_test.csv')

    # 5. Variant Comparison (Training)
    print("\n5. Generating variant comparison (training)...")
    training_comparison_df = None
    if training_df is not None:
        training_comparison_df = self.create_training_comparison(training_df)
        if training_comparison_df is not None:
            save_table(training_comparison_df, 'variant_comparison_training.csv')

    print("\n" + "=" * 80)
    print("AGGREGATION COMPLETE")
    print("=" * 80)

    return {
        'test_metrics': test_metrics_df,
        'training_info': training_df,
        'per_class': per_class_df,
        'variant_comparison_test': variant_comparison_df,
        'variant_comparison_training': training_comparison_df
    }
528
+
529
def print_summary_statistics(self, dfs):
    """Print a console report of the aggregated results.

    Args:
        dfs: dict shaped like the return value of ``generate_all_summaries``.
            Only the 'variant_comparison_test' and
            'variant_comparison_training' entries are read; either may be
            ``None``, in which case its section is skipped.
    """

    def fmt_mean(row, col):
        """Format a mean with 4 decimals, or 'N/A' when missing/NaN."""
        value = row.get(col)
        return f"{value:.4f}" if pd.notna(value) else 'N/A'

    def fmt_std(row, col):
        """Format a std with 4 decimals; a missing column prints as 'nan'."""
        return f"{row.get(col, float('nan')):.4f}"

    def fmt_count(row, col):
        """Format a lesion count; absent column -> 0, NaN cell -> 'N/A'.

        A NaN appears when variants with and without lesion columns share
        one DataFrame; int(NaN) would raise ValueError.
        """
        value = row.get(col, 0)
        return str(int(value)) if pd.notna(value) else 'N/A'

    print("\n" + "=" * 80)
    print("SUMMARY STATISTICS")
    print("=" * 80)

    test_df = dfs['variant_comparison_test']
    if test_df is not None:

        # ── Voxel-level Dice ─────────────────────────────────────────────
        print("\n📊 TEST DICE SCORES (Mean ± Std) across folds:")
        print("-" * 80)
        for _, row in test_df.iterrows():
            print(f"\nVariant {row['Variant']}: {row['Variant_Name']}")
            print(f" Overall: {row['DICE_Mean']:.4f} ± {row['DICE_Std']:.4f}")
            print(f" Ventricles: {row['DICE_Class1_Mean']:.4f} ± {row['DICE_Class1_Std']:.4f}")
            print(f" Abnormal WMH: {row['DICE_Class2_Mean']:.4f} ± {row['DICE_Class2_Std']:.4f}")

        # ── Lesion-level metrics ─────────────────────────────────────────
        lesion_cols_present = any(
            col.startswith('LESION_') for col in test_df.columns
        )
        if lesion_cols_present:
            print("\n\n🔬 LESION-LEVEL METRICS (Mean ± Std) across folds:")
            print("-" * 80)
            for _, row in test_df.iterrows():
                print(f"\nVariant {row['Variant']}: {row['Variant_Name']}")

                # Per-class (only the lesion class is reported)
                for class_idx, class_name in [(2, 'Abnormal WMH')]:
                    suffix = f'Class{class_idx}'
                    sens_col = f'LESION_LESION_SENSITIVITY_{suffix}_Mean'
                    prec_col = f'LESION_LESION_PRECISION_{suffix}_Mean'
                    f1_col = f'LESION_LESION_F1_{suffix}_Mean'
                    tp_col = f'LESION_TP_LESIONS_{suffix}_Total'
                    fp_col = f'LESION_FP_LESIONS_{suffix}_Total'
                    fn_col = f'LESION_FN_LESIONS_{suffix}_Total'
                    gt_col = f'LESION_N_GT_LESIONS_{suffix}_Total'

                    print(f" [{class_name}]")
                    if sens_col in row:
                        s_m = fmt_mean(row, sens_col)
                        s_s = fmt_std(row, f'LESION_LESION_SENSITIVITY_{suffix}_Std')
                        p_m = fmt_mean(row, prec_col)
                        p_s = fmt_std(row, f'LESION_LESION_PRECISION_{suffix}_Std')
                        f_m = fmt_mean(row, f1_col)
                        f_s = fmt_std(row, f'LESION_LESION_F1_{suffix}_Std')
                        print(f" Sensitivity : {s_m} ± {s_s}")
                        print(f" Precision : {p_m} ± {p_s}")
                        print(f" F1 : {f_m} ± {f_s}")
                    if gt_col in row:
                        print(f" GT Lesions : {fmt_count(row, gt_col)} "
                              f"TP: {fmt_count(row, tp_col)} "
                              f"FP: {fmt_count(row, fp_col)} "
                              f"FN: {fmt_count(row, fn_col)}")

    training_df = dfs['variant_comparison_training']
    if training_df is not None:
        print("\n\n🏆 TRAINING CONVERGENCE:")
        print("-" * 80)
        for _, row in training_df.iterrows():
            print(f"\nVariant {row['Variant']}: {row['Variant_Name']}")
            if 'Best_Epoch_Mean' in row:
                print(f" Best Epoch: {row['Best_Epoch_Mean']:.1f} ± {row['Best_Epoch_Std']:.1f}")
            if 'Best_Epoch_Dice_Abnormal_WMH_Mean' in row:
                print(f" Val Abnormal: {row['Best_Epoch_Dice_Abnormal_WMH_Mean']:.4f} ± {row['Best_Epoch_Dice_Abnormal_WMH_Std']:.4f}")
592
+
593
+
594
# Script entry point: aggregate all fold results, then report on the console.
if __name__ == "__main__":
    # Single source of truth for the output location, reused in the
    # confirmation message below.
    results_dir = './folds_results_zscore2_all'

    # Aggregator rooted at the current working directory
    aggregator = ResultsAggregator(base_dir='./')

    # Write every summary CSV, keeping the DataFrames for the console report
    dfs = aggregator.generate_all_summaries(output_dir=results_dir)
    aggregator.print_summary_statistics(dfs)

    print(f"\n✓ All CSV files have been generated in '{results_dir}' directory")
    print("\nGenerated files:")
    generated_files = (
        " 1. test_metrics_all_variants_folds.csv - Complete test metrics (voxel + lesion level)",
        " 2. training_info_all_variants_folds.csv - Training convergence info",
        " 3. per_class_summary.csv - Per-class statistics (voxel + lesion level)",
        " 4. variant_comparison_test.csv - Test metrics comparison (voxel + lesion level)",
        " 5. variant_comparison_training.csv - Training comparison",
    )
    for description in generated_files:
        print(description)
models/for_WMH_Vent/model_training_scripts/p4_inference.py ADDED
@@ -0,0 +1,1146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 Article - Inference Script for ventricles and WMH segmentation task
3
+
4
+ Developer:
5
+ Mahdi Bashiri Bawil
6
+ """
7
+
8
+ import tensorflow as tf
9
+ import os
10
+ from collections import defaultdict
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ from pathlib import Path
14
+ from tqdm import tqdm
15
+ import json
16
+ import nibabel as nib
17
+ import seaborn as sns
18
+ from sklearn.metrics import confusion_matrix, cohen_kappa_score, classification_report
19
+
20
+ from scipy.spatial.distance import directed_hausdorff
21
+ from scipy.ndimage import distance_transform_edt
22
+ from scipy.spatial.distance import cdist
23
+ from scipy.ndimage import binary_erosion
24
+ from scipy.ndimage import label as nd_label
25
+
26
+ from unet_model import build_unet_3class # must be updated with the actual used model for traininig
27
+
28
+ # Import data loader
29
+ from p4_data_loader import DataConfig, P2DataLoader
30
+
31
+ # Error analysis
32
+ from p4_error_analysis import run_error_analysis
33
+
34
+
35
print("TensorFlow Version:", tf.__version__)

###################### GPU Configuration ######################

# Request on-demand GPU memory allocation for every visible GPU.
# set_memory_growth raises RuntimeError if it is called too late, so the
# failure is reported instead of aborting the script.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected - inference will be slow")
else:
    try:
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("✅ GPU memory growth enabled")
        print(f" Available GPUs: {len(physical_devices)}")
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
51
+
52
+
53
+ ###################### Inference Configuration ######################
54
+
55
class InferenceConfig:
    """Configuration for inference.

    Derives the class layout, the checkpoint location and the output
    directories from the experiment identifiers, verifies that the model
    file exists, creates the output directories and prints a summary.

    Raises:
        FileNotFoundError: if the model file is missing. The check runs
            before any directory is created, so a failed construction
            leaves no empty output folders behind.
    """

    def __init__(self,
                 variant: int = 5,
                 preprocessing: str = 'standard',
                 class_scenario: str = '4class',
                 fold_id: int = 0,
                 model_name: str = 'best_dice_generator.h5',
                 architecture_name: str = 'unet'
                 ):

        # Experiment identification
        self.variant = variant
        self.preprocessing = preprocessing
        self.class_scenario = class_scenario
        self.fold_id = fold_id
        self.model_name = model_name
        self.architecture_name = architecture_name

        # Number of classes: any scenario other than '3class' is treated as
        # the 4-class layout.
        self.num_classes = 3 if class_scenario == '3class' else 4

        # Class names
        if self.num_classes == 4:
            self.class_names = ['Background', 'Ventricles', 'Normal_WMH', 'Abnormal_WMH']
        elif self.num_classes == 3:
            self.class_names = ['Background', 'Ventricles', 'Abnormal_WMH']

        # Image dimensions
        self.batch_size = 1  # Use batch_size=1 for inference
        self.img_width = 256
        self.img_height = 256

        # Paths (relative to the current working directory)
        self.results_dir = Path(f"results_fold_{fold_id}_var_{variant}_zscore2")
        self.models_dir = self.results_dir / "models" / f"{preprocessing}_{class_scenario}"
        self.checkpoint_dir = self.models_dir / f"fold_{fold_id}"

        # Output directories
        self.inference_dir = self.results_dir / "inference_all_test" / f"{preprocessing}_{class_scenario}"
        self.visualizations_dir = self.inference_dir / "visualizations"
        self.metrics_dir = self.inference_dir / "metrics"

        # Model path
        self.model_path = self.checkpoint_dir / self.model_name

        # Fail fast: validate the checkpoint before touching the filesystem.
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        # Create output directories only once the model is known to exist
        self.visualizations_dir.mkdir(parents=True, exist_ok=True)
        self.metrics_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n{'='*70}")
        print(f"INFERENCE CONFIGURATION")
        print(f"{'='*70}")
        print(f"Variant: {self.variant}")
        print(f"Preprocessing: {self.preprocessing}")
        print(f"Class scenario: {self.class_scenario} ({self.num_classes} classes)")
        print(f"Fold: {self.fold_id}")
        print(f"Architecture: {self.architecture_name}")
        print(f"Model: {self.model_name}")
        print(f"Model path: {self.model_path}")
        print(f"Output directory: {self.inference_dir}")
        print(f"{'='*70}\n")
124
+
125
+
126
+ ###################### Utility Functions ######################
127
+
128
def prepare_input(paired_input):
    """
    Return the FLAIR half of a side-by-side FLAIR/mask batch.

    Only slicing is done here: the first 256 columns of axis 2 are kept and
    no intensity rescaling is applied.

    Args:
        paired_input: 4-D batch indexed as [batch, row, column, channel];
            presumably (bs, 256, 512, 1) with FLAIR on the left and the
            mask on the right - confirm against the data loader.

    Returns:
        The [:, :, :256, :] slice of `paired_input`, i.e. the FLAIR part.
    """
    flair_columns = slice(None, 256)  # left half holds the FLAIR image
    return paired_input[:, :, flair_columns, :]
141
+
142
def compute_hd95(mask1, mask2):
    """
    Compute the 95th percentile Hausdorff Distance between two binary masks.

    The distance is measured between the mask surfaces, where a surface
    pixel is a foreground pixel removed by one binary erosion step (same
    definition as in compute_hd95_3d). Nearest-surface distances are
    gathered in both directions and pooled before taking the percentile.

    Args:
        mask1: Binary mask 1 (any values; non-zero counts as foreground)
        mask2: Binary mask 2, same shape as mask1

    Returns:
        HD95 value in pixels, or NaN when either mask is empty
    """
    mask1 = np.asarray(mask1).astype(bool)
    mask2 = np.asarray(mask2).astype(bool)

    if not mask1.any() or not mask2.any():
        return np.nan

    # Surface = mask minus its erosion. The earlier distance-transform test
    # (mask & dt(~mask) <= 1) was true for every foreground pixel, so the
    # whole region, not its boundary, took part in the distance.
    surface1 = mask1 & ~binary_erosion(mask1)
    surface2 = mask2 & ~binary_erosion(mask2)

    coords1 = np.argwhere(surface1)
    coords2 = np.argwhere(surface2)

    if len(coords1) == 0 or len(coords2) == 0:
        return np.nan

    # One pairwise matrix serves both directions: row minima give
    # surface1 -> surface2, column minima give surface2 -> surface1.
    pairwise = cdist(coords1, coords2, metric='euclidean')
    distances1 = pairwise.min(axis=1)
    distances2 = pairwise.min(axis=0)

    all_distances = np.concatenate([distances1, distances2])

    return np.percentile(all_distances, 95)
182
+
183
def compute_hd95_3d(mask1, mask2):
    """
    Compute the 95th percentile Hausdorff Distance for a 3D volume.

    Only surface voxels (mask minus its binary erosion) take part.
    Nearest-neighbour distances are found with a KD-tree, so the result is
    exact and deterministic for any surface size; no random subsampling is
    needed to bound memory.

    Args:
        mask1: Binary mask (N, H, W)
        mask2: Binary mask (N, H, W)

    Returns:
        HD95 value in voxel-index units (no physical spacing is applied),
        or NaN when either mask is empty
    """
    from scipy.ndimage import binary_erosion
    from scipy.spatial import cKDTree

    mask1 = np.asarray(mask1).astype(bool)
    mask2 = np.asarray(mask2).astype(bool)

    if not mask1.any() or not mask2.any():
        return np.nan

    # Surface = original mask minus eroded mask
    surface1 = mask1 & ~binary_erosion(mask1)
    surface2 = mask2 & ~binary_erosion(mask2)

    coords1 = np.argwhere(surface1)
    coords2 = np.argwhere(surface2)

    if len(coords1) == 0 or len(coords2) == 0:
        return np.nan

    # Directed nearest-surface distances in both directions (k=1 query)
    distances1, _ = cKDTree(coords2).query(coords1, k=1)
    distances2, _ = cKDTree(coords1).query(coords2, k=1)

    all_distances = np.concatenate([distances1, distances2])

    return np.percentile(all_distances, 95)
230
+
231
+
232
def compute_lesion_level_metrics(gt_volume, pred_volume, iou_threshold=0.1):
    """
    Compute lesion-level (instance-level) metrics by treating each connected
    component as an individual lesion.

    A GT lesion is DETECTED if its IoU with at least one single predicted
    component reaches `iou_threshold`. A predicted component is a TRUE
    POSITIVE if it reaches the threshold with at least one GT lesion,
    otherwise it is a FALSE POSITIVE lesion.

    Args:
        gt_volume    : binary 3-D numpy array (S, H, W) - ground truth for ONE class
        pred_volume  : binary 3-D numpy array (S, H, W) - prediction for ONE class
        iou_threshold: minimum IoU for a GT/prediction pair to match (default 0.1)

    Returns:
        dict with keys:
            n_gt_lesions      : total number of GT lesions
            n_pred_lesions    : total number of predicted lesion clusters
            tp_lesions        : GT lesions that were detected
            fn_lesions        : GT lesions that were missed
            fp_lesions        : predicted clusters matching no GT lesion
            lesion_sensitivity: tp_lesions / n_gt_lesions (0.0 without GT lesions)
            lesion_precision  : matched predicted clusters / n_pred_lesions
                                (0.0 without predictions); always within [0, 1]
            lesion_f1         : harmonic mean of sensitivity and precision
    """
    gt_bin = np.asarray(gt_volume).astype(bool)
    pred_bin = np.asarray(pred_volume).astype(bool)

    # Label connected components (label 0 is background in both volumes)
    gt_labeled, n_gt = nd_label(gt_bin)
    pred_labeled, n_pred = nd_label(pred_bin)

    # Voxel count per component, indexed by label id
    gt_sizes = np.bincount(gt_labeled.ravel(), minlength=n_gt + 1)
    pred_sizes = np.bincount(pred_labeled.ravel(), minlength=n_pred + 1)

    # Encode every overlapping (gt_id, pred_id) voxel as one integer key so
    # all pairwise intersections come from a single np.unique pass instead
    # of a full-volume scan per pair.
    overlap = (gt_labeled > 0) & (pred_labeled > 0)
    stride = n_pred + 1
    pair_keys, intersections = np.unique(
        gt_labeled[overlap].astype(np.int64) * stride + pred_labeled[overlap],
        return_counts=True,
    )
    pair_gt = pair_keys // stride
    pair_pred = pair_keys % stride

    unions = gt_sizes[pair_gt] + pred_sizes[pair_pred] - intersections
    matched = intersections / (unions + 1e-7) >= iou_threshold

    tp_lesions = int(np.unique(pair_gt[matched]).size)
    # Matched predictions are counted separately: one prediction may cover
    # several GT lesions, and dividing the GT-side count by n_pred would
    # then push precision above 1.
    tp_pred = int(np.unique(pair_pred[matched]).size)

    fn_lesions = n_gt - tp_lesions
    fp_lesions = n_pred - tp_pred

    lesion_sensitivity = tp_lesions / n_gt if n_gt > 0 else 0.0
    lesion_precision = tp_pred / n_pred if n_pred > 0 else 0.0
    denominator = lesion_sensitivity + lesion_precision
    lesion_f1 = (2 * lesion_sensitivity * lesion_precision / denominator
                 if denominator > 0 else 0.0)

    return {
        'n_gt_lesions': int(n_gt),
        'n_pred_lesions': int(n_pred),
        'tp_lesions': int(tp_lesions),
        'fn_lesions': int(fn_lesions),
        'fp_lesions': int(fp_lesions),
        'lesion_sensitivity': float(lesion_sensitivity),
        'lesion_precision': float(lesion_precision),
        'lesion_f1': float(lesion_f1),
    }
305
+
306
+
307
def compute_metrics_from_predictions(y_true, y_pred, num_classes, exclude_class=None):
    """
    Compute comprehensive metrics from predictions.

    Per-class binary masks are built directly with NumPy comparisons, so no
    dense one-hot tensor is allocated.

    Args:
        y_true: Ground truth class labels (N, H, W)
        y_pred: Predicted class labels (N, H, W)
        num_classes: Number of classes
        exclude_class: Class to exclude from metrics (e.g., 2 for Normal_WMH in 4-class)

    Returns:
        Dictionary with one sub-dict per voxel metric ('dice', 'precision',
        'recall', 'iou', 'specificity', 'hd95', 'TP'), each holding
        'class_<i>' entries for every evaluated class plus their 'mean',
        and a 'lesion' sub-dict with the output of
        compute_lesion_level_metrics for every evaluated class index > 1.
        The 'mean' ignores NaN entries (HD95 is NaN when a class is absent
        from either volume) and is NaN only if every entry is NaN.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    voxel_metric_names = ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']
    metrics = {name: {} for name in voxel_metric_names}

    classes_to_evaluate = [c for c in range(num_classes) if c != exclude_class]
    eps = 1e-7  # keeps every ratio defined when a class is absent

    for class_idx in classes_to_evaluate:
        key = f'class_{class_idx}'

        # Binary masks for this class over the whole volume
        true_mask = (y_true == class_idx)
        pred_mask = (y_pred == class_idx)

        # Confusion matrix elements
        TP = np.count_nonzero(true_mask & pred_mask)
        FP = np.count_nonzero(~true_mask & pred_mask)
        FN = np.count_nonzero(true_mask & ~pred_mask)
        TN = np.count_nonzero(~true_mask & ~pred_mask)

        metrics['dice'][key] = float((2 * TP) / (2 * TP + FP + FN + eps))
        metrics['precision'][key] = float(TP / (TP + FP + eps))
        metrics['recall'][key] = float(TP / (TP + FN + eps))
        metrics['iou'][key] = float(TP / (TP + FP + FN + eps))
        metrics['specificity'][key] = float(TN / (TN + FP + eps))
        # HD95 on the entire volume (all slices combined)
        metrics['hd95'][key] = float(compute_hd95_3d(true_mask, pred_mask))
        metrics['TP'][key] = float(TP)

    # Mean over the evaluated classes. NaN entries are skipped so that one
    # absent class does not turn the whole mean into NaN.
    for metric_name in voxel_metric_names:
        class_values = np.array(list(metrics[metric_name].values()), dtype=float)
        usable = class_values[~np.isnan(class_values)]
        metrics[metric_name]['mean'] = float(usable.mean()) if usable.size else float('nan')

    # --- Lesion-level metrics (connected-component analysis) ---
    metrics['lesion'] = {}
    for class_idx in classes_to_evaluate:
        if class_idx <= 1:  # skip background and ventricles
            continue
        metrics['lesion'][f'class_{class_idx}'] = compute_lesion_level_metrics(
            y_true == class_idx, y_pred == class_idx, iou_threshold=0.1
        )

    return metrics
401
+
402
+
403
+ # def aggregate_patient_metrics(per_patient_metrics, num_classes):
404
+ # """
405
+ # Returns both a flat structure (compatible with original overall_metrics)
406
+ # and an extended structure with std/n for richer reporting.
407
+ # """
408
+ # flat_metrics = {m: {} for m in ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']}
409
+ # rich_metrics = {m: {} for m in ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']}
410
+
411
+ # metric_names = ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']
412
+
413
+ # for metric_name in metric_names:
414
+ # for class_idx in range(num_classes):
415
+ # if class_idx == 0: continue
416
+
417
+ # key = f'class_{class_idx}'
418
+
419
+ # values = [
420
+ # per_patient_metrics[pid][metric_name][key]
421
+ # for pid in per_patient_metrics
422
+ # if key in per_patient_metrics[pid][metric_name]
423
+ # and not np.isnan(per_patient_metrics[pid][metric_name][key])
424
+ # ]
425
+
426
+ # TP_values = [
427
+ # per_patient_metrics[pid]['TP'][key]
428
+ # for pid in per_patient_metrics
429
+ # if key in per_patient_metrics[pid]['TP']
430
+ # and not np.isnan(per_patient_metrics[pid]['TP'][key])
431
+ # ]
432
+
433
+ # weighted_mean_values = np.sum((np.array(values) * np.array(TP_values)) / np.sum(np.array(TP_values)))
434
+
435
+ # mean_val = float(np.mean(values)) if values else np.nan
436
+ # std_val = float(np.std(values)) if values else np.nan
437
+
438
+ # # Flat: backward compatible with all existing print/save code
439
+ # flat_metrics[metric_name][key] = weighted_mean_values if metric_name != 'hd95' else mean_val
440
+
441
+ # # Rich: for extended reporting
442
+ # rich_metrics[metric_name][key] = {
443
+ # 'mean': mean_val,
444
+ # 'std': std_val,
445
+ # 'n': len(values)
446
+ # }
447
+
448
+ # # Mean across classes — same for both
449
+ # class_means = [
450
+ # flat_metrics[metric_name][f'class_{c}']
451
+ # for c in range(num_classes)
452
+ # if c!=0 and not np.isnan(flat_metrics[metric_name][f'class_{c}'])
453
+ # ]
454
+ # mean_across_classes = float(np.mean(class_means)) if class_means else np.nan
455
+ # flat_metrics[metric_name]['mean'] = mean_across_classes
456
+ # rich_metrics[metric_name]['mean'] = mean_across_classes
457
+
458
+ # return flat_metrics, rich_metrics
459
+
460
def aggregate_patient_metrics(per_patient_metrics, num_classes):
    """
    Aggregate per-patient metrics into dataset-level metrics.

    Returns both a flat structure (compatible with the original
    overall_metrics layout) and an extended structure with std/n for richer
    reporting.

    Voxel-level metrics (classes 1 .. num_classes-1, background skipped):
        - flat: TP-weighted mean across patients; for 'hd95' the plain mean.
          A patient contributes only when both its metric value and its TP
          weight are present and not NaN. NaN when no patient contributes
          or the weights sum to zero.
        - rich: {'mean', 'std', 'n'} of the unweighted per-patient values.
        - 'mean' entry: mean of the flat per-class values, NaN ones ignored.

    Lesion-level metrics (classes 2 .. num_classes-1):
        - lesion_sensitivity / lesion_precision / lesion_f1 : mean across patients
        - n_gt_lesions / n_pred_lesions / tp_lesions / fn_lesions / fp_lesions :
          summed across patients
        - 'mean_<scalar>' and 'total_<count>' combine the lesion classes.

    Args:
        per_patient_metrics: dict patient_id -> metrics dict with the keys
            produced by compute_metrics_from_predictions
        num_classes: number of classes including background

    Returns:
        (flat_metrics, rich_metrics)
    """
    voxel_metric_names = ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']
    foreground_keys = [f'class_{c}' for c in range(1, num_classes)]
    lesion_class_keys = [f'class_{c}' for c in range(2, num_classes)]

    def usable(entry, key):
        """Return entry[key] when present and not NaN, otherwise None."""
        value = entry.get(key)
        if value is None or np.isnan(value):
            return None
        return value

    # ── Voxel-level metrics ─────────────────────────────────────────────────
    flat_metrics = {m: {} for m in voxel_metric_names}
    rich_metrics = {m: {} for m in voxel_metric_names}

    for metric_name in voxel_metric_names:
        for key in foreground_keys:
            values = []            # every valid per-patient value
            paired_values = []     # values that also have a valid weight
            paired_weights = []
            for patient in per_patient_metrics.values():
                value = usable(patient[metric_name], key)
                if value is None:
                    continue
                values.append(value)
                # Value and weight are taken from the same patient so the
                # two arrays can never get out of step.
                weight = usable(patient['TP'], key)
                if weight is not None:
                    paired_values.append(value)
                    paired_weights.append(weight)

            mean_val = float(np.mean(values)) if values else np.nan
            std_val = float(np.std(values)) if values else np.nan

            if metric_name == 'hd95':
                flat_value = mean_val
            else:
                total_weight = float(np.sum(paired_weights))
                if total_weight > 0:
                    flat_value = float(np.sum(
                        np.array(paired_values) * np.array(paired_weights) / total_weight
                    ))
                else:
                    flat_value = np.nan

            flat_metrics[metric_name][key] = flat_value
            rich_metrics[metric_name][key] = {
                'mean': mean_val,
                'std': std_val,
                'n': len(values)
            }

        # Mean across classes (classes without data are NaN and skipped)
        class_means = [
            flat_metrics[metric_name][key]
            for key in foreground_keys
            if not np.isnan(flat_metrics[metric_name][key])
        ]
        mean_across_classes = float(np.mean(class_means)) if class_means else np.nan
        flat_metrics[metric_name]['mean'] = mean_across_classes
        rich_metrics[metric_name]['mean'] = mean_across_classes

    # ── Lesion-level metrics ────────────────────────────────────────────────
    # Scalar fields: averaged across patients (mean ± std)
    lesion_scalar_keys = ['lesion_sensitivity', 'lesion_precision', 'lesion_f1']
    # Count fields: summed across patients (total pool)
    lesion_count_keys = ['n_gt_lesions', 'n_pred_lesions', 'tp_lesions', 'fn_lesions', 'fp_lesions']

    flat_lesion = {}
    rich_lesion = {}
    flat_metrics['lesion'] = flat_lesion
    rich_metrics['lesion'] = rich_lesion

    for key in lesion_class_keys:
        records = [
            patient['lesion'][key]
            for patient in per_patient_metrics.values()
            if 'lesion' in patient and key in patient['lesion']
        ]
        flat_lesion[key] = {}
        rich_lesion[key] = {}

        for sk in lesion_scalar_keys:
            vals = [record[sk] for record in records]
            mean_val = float(np.mean(vals)) if vals else np.nan
            std_val = float(np.std(vals)) if vals else np.nan
            flat_lesion[key][sk] = mean_val
            rich_lesion[key][sk] = {
                'mean': mean_val,
                'std': std_val,
                'n': len(vals)
            }

        for ck in lesion_count_keys:
            total = int(np.sum([record[ck] for record in records])) if records else 0
            flat_lesion[key][ck] = total
            rich_lesion[key][ck] = total

    # Mean lesion scalars across the lesion classes
    for sk in lesion_scalar_keys:
        class_vals = [
            flat_lesion[key][sk]
            for key in lesion_class_keys
            if not np.isnan(flat_lesion[key][sk])
        ]
        mean_across = float(np.mean(class_vals)) if class_vals else np.nan
        flat_lesion[f'mean_{sk}'] = mean_across
        rich_lesion[f'mean_{sk}'] = mean_across

    # Summed counts across the lesion classes
    for ck in lesion_count_keys:
        grand_total = int(np.sum([flat_lesion[key][ck] for key in lesion_class_keys]))
        flat_lesion[f'total_{ck}'] = grand_total
        rich_lesion[f'total_{ck}'] = grand_total

    return flat_metrics, rich_metrics
590
+
591
+
592
+ ###################### Original Visualization Functions ######################
593
+
594
def visualize_prediction(flair, ground_truth, prediction,
                         probability_map, save_path,
                         sample_id, num_classes):
    """
    Save a 2x3 figure summarising one prediction.

    Panels: input FLAIR, ground truth, prediction, confidence map, binary
    error map, and the FLAIR overlaid with per-class prediction contours.

    Args:
        flair: Input FLAIR image (H, W)
        ground_truth: Ground truth mask (H, W)
        prediction: Predicted mask (H, W)
        probability_map: Max probability map (H, W)
        save_path: Path to save figure
        sample_id: Sample identifier shown in the figure title
        num_classes: Number of classes (sets the label colour range)
    """
    from scipy import ndimage

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    title_kwargs = {'fontsize': 14, 'fontweight': 'bold'}
    label_kwargs = {'cmap': 'jet', 'vmin': 0, 'vmax': num_classes - 1}
    unit_range = {'vmin': 0, 'vmax': 1}

    # 1.0 wherever the predicted label differs from the ground truth
    disagreement = (prediction != ground_truth).astype(float)

    # (axis, image, title, imshow options, draw a colorbar?)
    panels = [
        (axes[0, 0], flair, 'Input FLAIR', {'cmap': 'gray'}, False),
        (axes[0, 1], ground_truth, 'Ground Truth', label_kwargs, True),
        (axes[0, 2], prediction, 'Prediction', label_kwargs, True),
        (axes[1, 0], probability_map, 'Prediction Confidence',
         dict(cmap='viridis', **unit_range), True),
        (axes[1, 1], disagreement, 'Error Map (Red=Wrong)',
         dict(cmap='Reds', **unit_range), True),
    ]
    for ax, image, title, imshow_kwargs, with_colorbar in panels:
        handle = ax.imshow(image, **imshow_kwargs)
        ax.set_title(title, **title_kwargs)
        ax.axis('off')
        if with_colorbar:
            plt.colorbar(handle, ax=ax, fraction=0.046, pad=0.04)

    # Overlay: FLAIR + prediction contours, one colour per foreground class
    overlay_ax = axes[1, 2]
    overlay_ax.imshow(flair, cmap='gray')
    for class_idx in range(1, num_classes):  # Skip background
        region = (prediction == class_idx)
        # Outline = region minus its erosion
        outline = region ^ ndimage.binary_erosion(region)
        if not np.any(outline):
            continue
        overlay_ax.contour(
            outline,
            colors=[plt.cm.jet(class_idx/(num_classes-1))],
            linewidths=1.5,
        )
    overlay_ax.set_title('FLAIR + Prediction Overlay', **title_kwargs)
    overlay_ax.axis('off')

    plt.suptitle(f'Sample: {sample_id}', fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
657
+
658
+
659
def visualize_prediction_short(flair, ground_truth, prediction,
                               probability_map, save_path,
                               sample_id, num_classes):
    """
    Save a compact two-panel overlay figure for a single slice.

    The top panel blends the ground-truth labels over the FLAIR image and the
    bottom panel does the same for the predicted labels. Background pixels
    (label 0) show the plain FLAIR; foreground pixels are tinted with the
    'jet' colour of their label at 60% opacity.

    Args:
        flair: Input FLAIR image (H, W)
        ground_truth: Ground truth mask (H, W)
        prediction: Predicted mask (H, W)
        probability_map: Max probability map (H, W); accepted for signature
            parity with visualize_prediction but not drawn
        save_path: Path to save figure
        sample_id: Sample identifier; accepted for signature parity but not
            drawn (the figure has no title)
        num_classes: Number of classes (sets the label colour range)

    A failure while writing the file is reported on stdout and does not
    propagate, so one unwritable image does not abort a long inference run.
    """
    fig, axes = plt.subplots(2, 1, figsize=(6, 12))
    try:
        cmap = plt.cm.jet

        # Min-max normalise FLAIR to [0, 1]; epsilon guards a constant image
        flair_norm = (flair - flair.min()) / (flair.max() - flair.min() + 1e-8)
        flair_rgb = np.stack([flair_norm] * 3, axis=-1)

        # Avoid dividing by zero in the degenerate single-class case
        label_span = max(num_classes - 1, 1)

        for ax, mask in zip(axes, (ground_truth, prediction)):
            mask_rgb = cmap(mask / label_span)[..., :3]  # (H, W, 3), alpha channel dropped
            # Opaque tint on foreground only; background keeps the plain FLAIR
            alpha = np.where(mask > 0, 0.6, 0.0)[..., np.newaxis]
            ax.imshow(flair_rgb * (1 - alpha) + mask_rgb * alpha)
            ax.axis('off')

        fig.tight_layout()
        try:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        except Exception as exc:  # best-effort save: report and carry on
            print(f"\n Unsaved image: {save_path} ({exc})")
    finally:
        # Always release the figure, even if rendering failed, so figures do
        # not accumulate when this is called once per slice.
        plt.close(fig)
702
+
703
+
704
def save_prediction_as_nifti(prediction, save_path, reference_nifti=None):
    """
    Write a label map to disk as a NIfTI image.

    Args:
        prediction: Prediction array (H, W) or (H, W, D); cast to uint8
        save_path: Path to save NIfTI file
        reference_nifti: Optional reference NIfTI image. When given, its
            affine and header are reused; otherwise an identity affine is used.
    """
    label_data = prediction.astype(np.uint8)

    if reference_nifti is None:
        # No spatial reference available: fall back to an identity affine
        image = nib.Nifti1Image(label_data, np.eye(4))
    else:
        # Carry over orientation/spacing information from the reference
        image = nib.Nifti1Image(label_data, reference_nifti.affine, reference_nifti.header)

    nib.save(image, save_path)
721
+
722
+
723
+ ###################### Post-processing Function ######################
724
+
725
def post_process_pred(pred_classes, num_classes=3, min_object_size=5, closing_kernel_size=2):
    """
    Clean up one 2-D multi-class prediction slice.

    Parameters
    ----------
    pred_classes : np.ndarray of shape (H, W) with integer class labels
                   (e.g. the per-slice result of an argmax over class scores).
    num_classes : 4 -> labels are 0=BG, 1=Vent, 2=NormWMH, 3=AbWMH.
                  Any other value is handled as the 3-class layout
                  0=BG, 1=Vent, 2=AbWMH.
    min_object_size : passed to remove_small_objects as min_size; connected
                      components below this pixel count are dropped.
    closing_kernel_size : radius handed to disk() to build the footprint
                          used by binary_closing.

    Returns
    -------
    np.ndarray shaped and typed like pred_classes, holding the cleaned labels.

    Steps
    -----
    1. Split the label map into one binary mask per foreground class.
    2. Close each mask (fills small holes, bridges tiny gaps).
    3. Drop small connected components from each mask.
    4. Resolve overlaps created by closing with the priority
       Ventricles > Normal WMH > Abnormal WMH.
    5. Paint the masks back into a fresh label map (background = 0).
    """
    from skimage.morphology import remove_small_objects, binary_erosion, binary_closing, disk, binary_dilation

    footprint = disk(closing_kernel_size)

    def _tidy(binary):
        # Empty masks are returned untouched; others are closed, then despeckled.
        if binary.any():
            binary = binary_closing(binary, footprint)
            binary = remove_small_objects(binary, min_size=min_object_size)
        return binary

    four_class = (num_classes == 4)

    # Step 1: raw binary masks
    raw_vent = (pred_classes == 1)
    if four_class:
        raw_normal = (pred_classes == 2)
        raw_abnormal = (pred_classes == 3)
    else:
        # 3-class layout has no Normal WMH; Abnormal WMH is label 2
        raw_normal = np.zeros_like(raw_vent)
        raw_abnormal = (pred_classes == 2)

    # Steps 2-3: morphological cleaning, one class at a time
    ventricles = _tidy(raw_vent)
    normal_wmh = _tidy(raw_normal)
    abnormal_wmh = _tidy(raw_abnormal)

    # Step 4: higher-priority classes keep contested pixels
    normal_wmh = normal_wmh & ~ventricles
    abnormal_wmh = abnormal_wmh & ~(ventricles | normal_wmh)

    # Step 5: rebuild the integer label map
    cleaned = np.zeros_like(pred_classes)
    cleaned[ventricles] = 1
    if four_class:
        cleaned[normal_wmh] = 2
        cleaned[abnormal_wmh] = 3
    else:
        cleaned[abnormal_wmh] = 2

    return cleaned
801
+
802
+
803
+ ###################### Main Inference Function ######################
804
+
805
def _print_classwise_metric(overall_metrics, metric_key, heading, mean_label,
                            class_names, exclude_class=None):
    """Print one per-class metric table (4 decimals) followed by its mean."""
    values = overall_metrics[metric_key]
    print(f"\n{heading}")
    for class_idx, class_name in enumerate(class_names):
        if exclude_class is not None and class_idx == exclude_class:
            continue
        key = f'class_{class_idx}'
        if key in values:
            print(f" {class_name}: {values[key]:.4f}")
    print(f" Mean {mean_label}: {values['mean']:.4f}")


def run_inference(config: InferenceConfig):
    """
    Run a trained model over the test split of one fold and report metrics.

    Steps: build the test dataset, rebuild the architecture named by
    config.architecture_name and load its weights, predict every slice,
    save one overlay image per slice, compute metrics per patient volume,
    aggregate them across patients, print the tables and write them to
    ``test_metrics_complete.json`` in config.metrics_dir.

    Args:
        config: InferenceConfig object

    Returns:
        Dictionary with keys:
            'patients_results': per-patient lists of predictions, ground
                truths, max-probability maps, FLAIR slices and slice indices
            'metrics': flat aggregated metrics
            'rich_metrics': aggregated metrics in the richer form returned
                by aggregate_patient_metrics

    Raises:
        ValueError: if config.architecture_name is not a supported name.
        Exception: any error raised while importing/building the model or
            loading its weights is printed and re-raised.
    """
    print("\n" + "="*70)
    print(f"RUNNING INFERENCE")
    print("="*70)

    # Initialize data loader
    data_config = DataConfig()
    data_loader = P2DataLoader(data_config)

    def _make_test_dataset():
        # Single place that defines the test dataset, so the re-creation
        # below cannot drift from the first construction.
        return data_loader.create_dataset_for_fold(
            fold_id=config.fold_id,
            split='test',
            preprocessing=config.preprocessing,
            class_scenario=config.class_scenario,
            batch_size=config.batch_size,
            shuffle=False
        )

    # Load test dataset
    print("Loading test data...")
    test_dataset = _make_test_dataset()

    # Get dataset size; a negative cardinality means TensorFlow cannot tell,
    # so count by iterating once and then build a fresh dataset.
    test_size = tf.data.experimental.cardinality(test_dataset).numpy()
    if test_size < 0:
        test_size = sum(1 for _ in test_dataset)
        test_dataset = _make_test_dataset()

    print(f"Test samples: {test_size}\n")

    # Load model
    print(f"Loading model from: {config.model_path}")
    try:
        # The builder must match the architecture that was used for training.
        if config.architecture_name == 'unet':
            from unet_model import build_unet_3class as build_specific_3class
        elif config.architecture_name == 'attnunet':
            from attn_unet_model import build_attention_unet_3class as build_specific_3class
        elif config.architecture_name == 'dlv3unet':
            from dlv3_unet_model_GN import build_deeplabv3_unet_3class as build_specific_3class
        elif config.architecture_name == 'transunet':
            from trans_unet_model import build_trans_unet_3class as build_specific_3class
        else:
            # A bare `raise` here has no active exception to re-raise and
            # would surface as an unrelated RuntimeError.
            raise ValueError(
                f"Invalid Model Name: {config.architecture_name!r} "
                "(expected one of 'unet', 'attnunet', 'dlv3unet', 'transunet')"
            )

        # Build model architecture first
        generator = build_specific_3class(
            input_shape=(256, 256, 1),
            num_classes=config.num_classes
        )

        # Load weights
        generator.load_weights(str(config.model_path))
        print("✅ Model loaded successfully\n")

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise

    # Initialize storage - keyed by patient ID
    patient_results = defaultdict(lambda: {
        'predictions': [],
        'ground_truths': [],
        'probabilities': [],
        'flairs': [],
        'slice_indices': []
    })
    sample_ids = []

    # Run inference
    print("Running inference on test set...")
    test_bar = tqdm(test_dataset, total=test_size, desc="Inference")

    for paired_input, target_mask, patient_id_tensor, slice_num_tensor in test_bar:

        # Prepare input
        flair_normalized = prepare_input(paired_input)

        # Generate prediction
        prediction_softmax = generator(flair_normalized, training=False)

        # Convert the whole batch to numpy once
        pred_batch = tf.argmax(prediction_softmax, axis=-1).numpy()
        prob_batch = tf.reduce_max(prediction_softmax, axis=-1).numpy()
        gt_batch = target_mask.numpy()
        flair_batch = flair_normalized.numpy()
        patient_ids = patient_id_tensor.numpy()
        slice_nums = slice_num_tensor.numpy()

        # Handle every element of the batch; indexing only element 0 would
        # silently drop samples whenever config.batch_size > 1.
        for b in range(pred_batch.shape[0]):
            patient_id = patient_ids[b].decode('utf-8')  # bytes -> str
            slice_num = int(slice_nums[b])
            sample_id = f"{patient_id}_slice_{slice_num:03d}"
            sample_ids.append(sample_id)

            pred_classes = pred_batch[b]
            max_prob = prob_batch[b]
            ground_truth = gt_batch[b]
            flair = flair_batch[b, :, :, 0]

            # Post-processing is currently disabled:
            # pred_classes = post_process_pred(pred_classes, num_classes=config.num_classes)

            # Store per-patient
            record = patient_results[patient_id]
            record['predictions'].append(pred_classes)
            record['ground_truths'].append(ground_truth)
            record['probabilities'].append(max_prob)
            record['flairs'].append(flair)
            record['slice_indices'].append(slice_num)

            # Every slice is visualised, one file per sample id
            viz_path = config.visualizations_dir / f"{sample_id}.png"
            visualize_prediction_short(
                flair, ground_truth, pred_classes,
                max_prob, viz_path,
                sample_id, config.num_classes
            )

    print("\n✅ Inference complete!\n")

    # Compute per-patient metrics
    print("Computing metrics...")
    exclude_class = None
    per_patient_metrics = {}

    for patient_id, data in patient_results.items():
        # Sort slices by their slice index before stacking into a volume
        order = np.argsort(data['slice_indices'])

        gt_volume = np.array(data['ground_truths'])[order]    # (S, H, W)
        pred_volume = np.array(data['predictions'])[order]    # (S, H, W)

        per_patient_metrics[patient_id] = compute_metrics_from_predictions(
            gt_volume,
            pred_volume,
            config.num_classes
        )
        print(f"\nPatint_id : {patient_id} , Stats: {per_patient_metrics[patient_id]}\n")

        pm = per_patient_metrics[patient_id]
        print(f"\nPatient_id: {patient_id}")
        print(f" Voxel — Dice: { {k: round(v,4) for k,v in pm['dice'].items()} }")
        if 'lesion' in pm:
            for cls, ld in pm['lesion'].items():
                print(f" Lesion [{cls}] — "
                      f"GT:{ld['n_gt_lesions']} Pred:{ld['n_pred_lesions']} "
                      f"TP:{ld['tp_lesions']} FP:{ld['fp_lesions']} FN:{ld['fn_lesions']} "
                      f"Sens:{ld['lesion_sensitivity']:.3f} Prec:{ld['lesion_precision']:.3f} "
                      f"F1:{ld['lesion_f1']:.3f}")

    # Aggregate across patients
    # overall_metrics      -> flat values used by the print/save code below
    # overall_metrics_rich -> use wherever mean ± std reporting is wanted
    overall_metrics, overall_metrics_rich = aggregate_patient_metrics(
        per_patient_metrics, config.num_classes
    )

    # Print standard metrics
    print("\n" + "="*70)
    print("STANDARD METRICS (Class vs Rest)")
    print("="*70)

    metric_tables = [
        ('dice', "Class-wise Dice Scores:", "Dice"),
        ('precision', "Class-wise Precision:", "Precision"),
        ('recall', "Class-wise Recall:", "Recall"),
        ('iou', "Class-wise IoU:", "IoU"),
        ('specificity', "Class-wise Specificity:", "Specificity"),
        ('hd95', "Class-wise HD95 (lower is better):", "HD95"),
    ]
    for metric_key, heading, mean_label in metric_tables:
        _print_classwise_metric(
            overall_metrics, metric_key, heading, mean_label,
            config.class_names, exclude_class
        )

    print("="*70 + "\n")

    # Print lesion-level metrics
    print("\n" + "="*70)
    print("LESION-LEVEL METRICS (Connected-Component Analysis)")
    print("="*70)

    lesion_metrics = overall_metrics.get('lesion', {})

    for class_idx, class_name in enumerate(config.class_names):
        if class_idx == 0:
            continue  # background has no lesions
        key = f'class_{class_idx}'
        if key not in lesion_metrics:
            continue
        ld = lesion_metrics[key]
        print(f"\n [{class_name}]")
        print(f" GT Lesions : {ld['n_gt_lesions']}")
        print(f" Predicted Lesions : {ld['n_pred_lesions']}")
        print(f" TP Lesions : {ld['tp_lesions']}")
        print(f" FP Lesions : {ld['fp_lesions']}")
        print(f" FN Lesions : {ld['fn_lesions']}")
        print(f" Lesion Sensitivity : {ld['lesion_sensitivity']:.4f}")
        print(f" Lesion Precision : {ld['lesion_precision']:.4f}")
        print(f" Lesion F1 : {ld['lesion_f1']:.4f}")

    # Guarded like the loop above, so a missing 'lesion' section does not
    # raise KeyError after all the work has been done.
    if lesion_metrics:
        print(f"\n [Summary across foreground classes]")
        print(f" Total GT Lesions : {lesion_metrics['total_n_gt_lesions']}")
        print(f" Total Pred Lesions : {lesion_metrics['total_n_pred_lesions']}")
        print(f" Total TP Lesions : {lesion_metrics['total_tp_lesions']}")
        print(f" Total FP Lesions : {lesion_metrics['total_fp_lesions']}")
        print(f" Total FN Lesions : {lesion_metrics['total_fn_lesions']}")
        print(f" Mean Lesion Sensitivity : {lesion_metrics['mean_lesion_sensitivity']:.4f}")
        print(f" Mean Lesion Precision : {lesion_metrics['mean_lesion_precision']:.4f}")
        print(f" Mean Lesion F1 : {lesion_metrics['mean_lesion_f1']:.4f}")
    print("="*70 + "\n")

    # Save all metrics to JSON
    metrics_file = config.metrics_dir / "test_metrics_complete.json"

    def convert_to_serializable(obj):
        """Recursively convert numpy containers/scalars to native Python types."""
        if isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [convert_to_serializable(v) for v in obj]
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj

    metrics_to_save = {
        'config': {
            'variant': int(config.variant),
            'preprocessing': config.preprocessing,
            'class_scenario': config.class_scenario,
            'fold_id': int(config.fold_id),
            'num_classes': int(config.num_classes),
            'class_names': config.class_names,
            'architecture_name': config.architecture_name,
            'model_name': config.model_name,
            # Number of slices actually processed (equals the dataset size
            # when batch_size is 1)
            'test_samples': len(sample_ids)
        },
        'metrics': convert_to_serializable(overall_metrics)
    }

    with open(metrics_file, 'w') as f:
        json.dump(metrics_to_save, f, indent=2)

    print(f"\n✅ All metrics saved to: {metrics_file}")
    # print(f"✅ Predictions saved to: {config.predictions_dir}")
    print(f"✅ Visualizations saved to: {config.visualizations_dir}")

    # Return results
    return {
        'patients_results': patient_results,
        'metrics': overall_metrics,
        'rich_metrics': overall_metrics_rich
    }
1106
+
1107
+
1108
+ ###################### Main Execution ######################
1109
+
1110
if __name__ == "__main__":
    # Script entry point: run inference followed by error analysis for every
    # (fold, preprocessing, class scenario) combination listed below.

    preprocess_options = ['standard']  # ['zoomed', 'standard']
    scenarios = ['3class']  # ['3class', '4class']
    fold_numbers = list(np.array([0, 1, 2, 3]))

    # Fold is the outermost dimension, scenario the innermost
    sweep = [
        (fold_number, preprocess_option, scenario)
        for fold_number in fold_numbers
        for preprocess_option in preprocess_options
        for scenario in scenarios
    ]

    for fold_number, preprocess_option, scenario in sweep:
        config = InferenceConfig(
            variant=1,
            preprocessing=preprocess_option,
            class_scenario=scenario,
            fold_id=fold_number,
            model_name='best_dice_model.h5',
            architecture_name='unet'  # a choice from ['unet', 'attnunet', 'dlv3unet', 'transunet']
        )

        results = run_inference(config)

        # ── Error Analysis ──────────────────────────────────────
        # NOTE(review): the three ranking weights below sum to 1.1 — confirm
        # whether run_error_analysis expects them to sum to 1.
        error_results = run_error_analysis(
            results=results,
            config=config,
            top_n_slices=300,       # visualise N hardest slices
            top_n_patients=20,      # patient summary plots
            fg_dice_weight=0.7,     # tunable ranking weights
            error_rate_weight=0.2,
            confidence_weight=0.2,
        )
        # ────────────────────────────────────────────────────────

    # NOTE(review): indentation was lost in the source view; this banner is
    # assumed to run once after the whole sweep — confirm.
    print("\n" + "="*70)
    print("INFERENCE + ERROR ANALYSIS COMPLETE")
    print("="*70)
models/for_WMH_Vent/model_training_scripts/p4_run_experiments_all.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 Article - Run Multiple Variant Experiments
3
+ Updated runner script supporting all models
4
+
5
+ Supports:
6
+ - Variant 1: Baseline U-Net
7
+ - Variant 2: Attention U-Net
8
+ - Variant 3: DeepLabV3+ U-Net
9
+ - Variant 4: Trans U-Net
10
+
11
+ Usage:
12
+ # Single experiment
13
+ python p4_run_experiments_all.py --variant 2 --fold 0 --scenario standard_3class
14
+
15
+ # All scenarios for one variant+fold
16
+ python p4_run_experiments_all.py --variant 2 --fold 0
17
+
18
+ # All scenarios for one variant (all folds)
19
+ python p4_run_experiments_all.py --variant 2
20
+
21
+ # All scenarios (all folds and all variants)
22
+ python p4_run_experiments_all.py
23
+ """
24
+
25
+ import sys
26
+ import argparse
27
+ import subprocess
28
+ from pathlib import Path
29
+ import tensorflow as tf
30
+ import gc
31
+ from tensorflow.keras import backend as K
32
+
33
+ import p4_unet_viz
34
+
35
+
36
def clear_gpu_memory():
    """
    Release TensorFlow/Keras state between experiments.

    Clears the Keras session, forces a garbage-collection pass, resets the
    TF1-compat default graph and, when possible, resets the memory
    statistics tracked for device 'GPU:0'. The statistics reset is
    best-effort: if it fails (for example when no such device exists) the
    reason is printed and cleanup still completes.
    """
    print("\n" + "="*70)
    print("CLEANING UP GPU MEMORY")
    print("="*70)

    # Clear Keras session
    K.clear_session()
    print("✅ Cleared Keras session")

    # Force garbage collection
    gc.collect()
    print("✅ Ran garbage collection")

    # Reset TensorFlow graphs
    tf.compat.v1.reset_default_graph()
    print("✅ Reset default graph")

    # Reset the memory statistics for the first GPU. This call resets
    # tracked statistics only; it does not itself free memory.
    try:
        tf.config.experimental.reset_memory_stats('GPU:0')
        print("✅ Reset GPU memory stats")
    except Exception as exc:
        # Optional step: report the reason and keep going without aborting.
        print(f"Skipped GPU memory stats reset: {exc}")

    print("="*70 + "\n")
63
+
64
+
65
# Variant number -> architecture name understood by the training and
# inference configs.
_VARIANT_ARCHITECTURES = {
    1: 'unet',        # Baseline U-Net
    2: 'attnunet',    # Attention U-Net
    3: 'dlv3unet',    # DeepLabV3+ U-Net
    4: 'transunet',   # Trans U-Net
}


def run_single_experiment(variant: int,
                          preprocessing: str,
                          class_scenario: str,
                          fold_id: int) -> bool:
    """
    Train, evaluate and error-analyse one model variant on one fold.

    The variant number selects the architecture; the same pipeline is then
    run for every variant: train, plot the training history, run inference
    with 'best_dice_model.h5', and run the error analysis.

    Args:
        variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)
        preprocessing: 'standard' or 'zoomed'
        class_scenario: '3class' or '4class'
        fold_id: index of the cross-validation fold

    Returns:
        True if successful, False otherwise. Any exception (including an
        unknown variant) is printed with its traceback and reported as False.
    """
    print("\n" + "="*80)
    print(f"RUNNING: Variant {variant} | {preprocessing} | {class_scenario} | Fold {fold_id}")
    print("="*80 + "\n")

    try:
        architecture_name = _VARIANT_ARCHITECTURES.get(variant)
        if architecture_name is None:
            raise ValueError(f"Unknown variant: {variant}")

        # Imported lazily so that heavy project modules are loaded only when
        # an experiment actually runs.
        from p4_variant_all_net import ExperimentConfig, train_net

        train_config = ExperimentConfig(
            variant=variant,
            preprocessing=preprocessing,
            class_scenario=class_scenario,
            fold_id=fold_id,
            architecture_name=architecture_name
        )

        _, history_path = train_net(train_config)
        p4_unet_viz.main_viz(history_path)

        # Run Inference
        from p4_inference import InferenceConfig, run_inference, run_error_analysis

        inference_config = InferenceConfig(
            variant=variant,
            preprocessing=preprocessing,
            class_scenario=class_scenario,
            fold_id=fold_id,
            model_name='best_dice_model.h5',
            architecture_name=architecture_name
        )

        results = run_inference(inference_config)

        # ── Error Analysis ──────────────────────────────────────
        run_error_analysis(
            results=results,
            config=inference_config,
            top_n_slices=30,        # visualise N hardest slices
            top_n_patients=10,      # patient summary plots
            fg_dice_weight=0.6,     # tunable ranking weights
            error_rate_weight=0.2,
            confidence_weight=0.2,
        )

        print(f"\n✅ Experiment completed successfully!")
        return True

    except Exception as e:
        print(f"\n❌ Experiment failed with error:")
        print(f" {str(e)}")
        import traceback
        traceback.print_exc()
        return False
258
+
259
+
260
def run_all_scenarios_for_variant_fold(variant: int, fold_id: int) -> dict:
    """
    Run every preprocessing/class scenario for a given variant and fold.

    Each scenario is launched as a separate Python process running this
    script with --variant/--fold/--scenario, so that GPU and host memory are
    fully released between experiments.

    Args:
        variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)
        fold_id: index of the cross-validation fold

    Returns:
        Dictionary keyed by experiment name
        ("v{variant}_{preprocessing}_{class_scenario}_fold{fold_id}") whose
        values are {'status': 'SUCCESS'} or
        {'status': 'FAILED', 'error': <message>}. Scenarios that were not
        attempted because the user stopped the run are absent.
    """
    import time

    # Execution order of the scenarios
    experiments = [
        {'preprocessing': 'zoomed', 'class_scenario': '4class'},
        {'preprocessing': 'standard', 'class_scenario': '4class'},
        {'preprocessing': 'zoomed', 'class_scenario': '3class'},
        {'preprocessing': 'standard', 'class_scenario': '3class'},
    ]
    total = len(experiments)

    print("\n" + "="*80)
    print(f"RUNNING ALL SCENARIOS FOR VARIANT {variant}, FOLD {fold_id}")
    print("="*80)
    print(f"\nTotal experiments: {total}")
    # List the plan from the actual schedule so it cannot disagree with
    # the order in which the experiments run.
    for idx, scenario in enumerate(experiments, 1):
        print(f" {idx}. {scenario['preprocessing']} + {scenario['class_scenario']}")
    print("\n" + "="*80 + "\n")

    # Resolve this script by absolute path so the child process starts
    # regardless of the current working directory.
    script_path = str(Path(__file__).resolve())

    results = {}

    for idx, scenario in enumerate(experiments, 1):
        preprocessing = scenario['preprocessing']
        class_scenario = scenario['class_scenario']
        exp_name = f"v{variant}_{preprocessing}_{class_scenario}_fold{fold_id}"

        print(f"\n{'#'*80}")
        print(f"SCENARIO {idx}/{total}: {preprocessing} + {class_scenario}")
        print(f"{'#'*80}\n")

        # Run in subprocess for complete memory isolation
        cmd = [
            sys.executable,
            script_path,
            '--variant', str(variant),
            '--fold', str(fold_id),
            '--scenario', f"{preprocessing}_{class_scenario}"
        ]

        print(f"Running command: {' '.join(cmd)}\n")

        try:
            # check=True turns a non-zero exit status into CalledProcessError
            subprocess.run(cmd, check=True, capture_output=False)
        except subprocess.CalledProcessError as e:
            print(f"\n❌ Error in {preprocessing} + {class_scenario}")
            print(f" Error: {str(e)}")
            results[exp_name] = {
                'status': 'FAILED',
                'error': str(e)
            }

            # Ask user if they want to continue. Without an interactive
            # stdin there is nobody to ask, so stop and still print the
            # summary below.
            try:
                response = input("\nContinue with remaining experiments? (y/n): ")
            except EOFError:
                response = 'n'
            if response.lower() != 'y':
                print("Stopping experiments...")
                break
        else:
            results[exp_name] = {'status': 'SUCCESS'}
            print(f"\n✅ {exp_name} completed successfully")

        # Brief pause between experiments (not needed after the last one)
        if idx < total:
            print("\n⏳ Waiting 5 seconds before next experiment...")
            time.sleep(5)

    # Summary
    print("\n" + "="*80)
    print(f"VARIANT {variant}, FOLD {fold_id} - SUMMARY")
    print("="*80)

    for exp_name, result in results.items():
        status_icon = "✅" if result['status'] == 'SUCCESS' else "❌"
        print(f"{status_icon} {exp_name}")

    print("\n" + "="*80 + "\n")

    return results
352
+
353
+
354
def run_all_folds_for_variant(variant: int) -> dict:
    """
    Run every scenario on folds 0-4 for one variant, after user confirmation.

    The user is asked to confirm on stdin first; any answer other than 'y'
    (case-insensitive) cancels and returns an empty dictionary. Otherwise
    run_all_scenarios_for_variant_fold is called for each fold in turn and a
    per-fold status summary is printed at the end.

    Args:
        variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)

    Returns:
        Dictionary mapping 'fold_<id>' to that fold's per-experiment results
    """
    fold_ids = range(5)

    print("\n" + "="*80)
    print(f"RUNNING ALL FOLDS FOR VARIANT {variant}")
    print("="*80)
    print("\nTotal experiments: 4 scenarios × 5 folds = 20 training runs")
    print("Estimated time: ~0.7 hour per experiment (with 60 epochs)")
    print("Total estimated time: 10-20 hours")
    print("\n" + "="*80 + "\n")

    answer = input("This will take a long time. Continue? (y/n): ")
    if answer.lower() != 'y':
        print("Cancelled.")
        return {}

    all_results = {}

    for fold_id in fold_ids:
        print(f"\n{'='*80}")
        print(f"STARTING FOLD {fold_id}")
        print(f"{'='*80}\n")

        all_results[f'fold_{fold_id}'] = run_all_scenarios_for_variant_fold(variant, fold_id)

    # Final summary
    print("\n" + "="*80)
    print(f"VARIANT {variant} - ALL FOLDS COMPLETE")
    print("="*80)

    for fold_id in fold_ids:
        outcomes = all_results.get(f'fold_{fold_id}')
        if outcomes is None:
            continue
        print(f"\nFold {fold_id}:")
        for exp_name, outcome in outcomes.items():
            marker = "✅" if outcome['status'] == 'SUCCESS' else "❌"
            print(f" {marker} {exp_name}")

    print("\n" + "="*80 + "\n")

    return all_results
405
+
406
+
407
def _read_final_val_loss(history_path):
    """Return the last ``val_loss`` value stored in *history_path*, or None.

    None is returned when the history holds no ``val_loss`` entries, so the
    caller can report the gap instead of failing on an empty list.
    """
    import json  # function-scope import, as in the original implementation

    with open(history_path, 'r') as f:
        val_losses = json.load(f).get('val_loss') or []
    return val_losses[-1] if val_losses else None


def _format_improvement(reference, candidate):
    """Format the relative loss reduction of *candidate* versus *reference* in percent."""
    if reference == 0:
        # A relative change against a zero loss is undefined.
        return "n/a (reference loss is 0)"
    return f"{((reference - candidate) / reference) * 100:+.2f}%"


def compare_variants(fold_id: int = 0):
    """
    Compare the final validation loss of variants 1, 2 and 3 for one fold.

    For each of the four preprocessing/class scenarios the function reads
    ``history.json`` from the three variant directories below
    ``results_fold_<fold_id>/models/<scenario>/`` (relative to the current
    working directory) and prints the last ``val_loss`` of each variant plus
    the pairwise relative improvements. Scenarios with a missing history
    file, or a history without ``val_loss`` entries, are reported and skipped.

    Args:
        fold_id: Fold to compare (0-4)
    """
    rule = "=" * 80
    print("\n" + rule)
    print(f"COMPARING VARIANTS FOR FOLD {fold_id}")
    print(rule)

    scenarios = [
        {'preprocessing': 'standard', 'class_scenario': '3class'},
        {'preprocessing': 'standard', 'class_scenario': '4class'},
        {'preprocessing': 'zoomed', 'class_scenario': '3class'},
        {'preprocessing': 'zoomed', 'class_scenario': '4class'},
    ]

    results_dir = Path(f"results_fold_{fold_id}")

    # (report label, directory suffix): variant 1 has no suffix, variants 2
    # and 3 live in "fold_<id>_variant2" / "fold_<id>_variant3".
    variant_dirs = [
        ("Baseline", ""),
        ("Attention", "_variant2"),
        ("NewLoss", "_variant3"),
    ]

    for scenario in scenarios:
        scenario_name = f"{scenario['preprocessing']}_{scenario['class_scenario']}"
        print(f"\n{scenario['preprocessing']} + {scenario['class_scenario']}:")
        print("-" * 60)

        history_paths = {
            label: results_dir / "models" / scenario_name / f"fold_{fold_id}{suffix}" / "history.json"
            for label, suffix in variant_dirs
        }

        missing = [label for label, path in history_paths.items() if not path.exists()]
        if missing:
            for label in missing:
                print(f" ⚠️ {label} results not found")
            continue

        final_losses = {label: _read_final_val_loss(path) for label, path in history_paths.items()}
        empty = [label for label, loss in final_losses.items() if loss is None]
        if empty:
            for label in empty:
                print(f" ⚠️ {label} history has no val_loss entries")
            continue

        # Compare final validation losses
        baseline_val = final_losses["Baseline"]
        attention_val = final_losses["Attention"]
        newloss_val = final_losses["NewLoss"]

        print(f" Baseline Val Loss: {baseline_val:.4f}")
        print(f" Attention Val Loss: {attention_val:.4f}")
        print(f" NewLoss Val Loss: {newloss_val:.4f}")
        print(f" Improvement by V2 on V1: {_format_improvement(baseline_val, attention_val)}")
        print(f" Improvement by V3 on V1: {_format_improvement(baseline_val, newloss_val)}")
        print(f" Improvement by V3 on V2: {_format_improvement(attention_val, newloss_val)}")

    print("\n" + rule + "\n")
480
+
481
+
482
def main():
    """
    Parse the command line and dispatch to the requested run mode.

    Modes, in order of precedence:
        --compare              compare variants for --fold (fold 0 if omitted)
        --variant + --scenario run one experiment (fold 0 if --fold omitted);
                               exits with status 1 when it fails
        --variant + --fold     run all scenarios for that fold
        --variant only         run all scenarios for all folds

    ``--variant`` is required for every mode except ``--compare``.
    """
    # The epilog is printed verbatim (RawDescriptionHelpFormatter).
    examples = "\n".join([
        "",
        "Examples:",
        "# Single experiment",
        "python p4_run_experiments_all.py --variant 2 --fold 0 --scenario standard_3class",
        "",
        "# All scenarios for variant 2, fold 0",
        "python p4_run_experiments_all.py --variant 2 --fold 0",
        "",
        "# All folds for variant 3",
        # Fixed: this example used to show "--variant 2" under a "variant 3" heading.
        "python p4_run_experiments_all.py --variant 3",
        "",
        "# Compare results",
        "python p4_run_experiments_all.py --compare --fold 0",
        "",
    ])

    parser = argparse.ArgumentParser(
        description='Run P4 experiments for multiple variants',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples
    )

    parser.add_argument(
        '--variant',
        type=int,
        choices=[1, 2, 3, 4],
        help='variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)'
    )

    parser.add_argument(
        '--fold',
        type=int,
        choices=[0, 1, 2, 3, 4],
        help='Specific fold to train (0-4)'
    )

    parser.add_argument(
        '--scenario',
        type=str,
        choices=['standard_3class', 'standard_4class', 'zoomed_3class', 'zoomed_4class'],
        help='Specific scenario to train'
    )

    parser.add_argument(
        '--compare',
        action='store_true',
        help='Compare results between variants'
    )

    args = parser.parse_args()

    # Handle comparison mode (NOT READY YET!)
    if args.compare:
        fold_id = args.fold if args.fold is not None else 0
        compare_variants(fold_id)
        return

    # Validate arguments
    if args.variant is None:
        parser.error("--variant is required (unless using --compare)")

    # Single experiment
    if args.scenario is not None:
        # Scenario names have the form "<preprocessing>_<class_scenario>".
        preprocessing, class_scenario = args.scenario.split('_', 1)
        fold_id = args.fold if args.fold is not None else 0

        print(f"\nRunning single experiment:")
        print(f" Variant: {args.variant}")
        print(f" Fold: {fold_id}")
        print(f" Preprocessing: {preprocessing}")
        print(f" Class scenario: {class_scenario}\n")

        success = run_single_experiment(
            variant=args.variant,
            preprocessing=preprocessing,
            class_scenario=class_scenario,
            fold_id=fold_id
        )

        if success:
            print("\n✅ Experiment complete!")
        else:
            print("\n❌ Experiment failed!")
            sys.exit(1)

    # All scenarios for specific fold
    elif args.fold is not None:
        run_all_scenarios_for_variant_fold(args.variant, args.fold)

    # All scenarios for all folds
    else:
        run_all_folds_for_variant(args.variant)
573
+
574
+
575
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    # main() returns None on success, which SystemExit maps to exit status 0.
    raise SystemExit(main())
models/for_WMH_Vent/model_training_scripts/p4_unet_viz.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 - All U-Net models with Adaptive Loss (WCE + UFL)
3
+
4
+ WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
5
+ Three-class segmentation: Background vs Ventricles vs Abnormal WMH
6
+ Professional results saving and visualization for publication
7
+
8
+ This relates to our article:
9
+ "Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
10
+ A Large-Scale Multiple Sclerosis Study"
11
+
12
+ Features:
13
+ - Visualization of Results
14
+
15
+ Authors:
16
+ "Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
17
+
18
+ Developer:
19
+ "Mahdi Bashiri Bawil"
20
+ """
21
+
22
+ import os
23
+ import json
24
+ import matplotlib.pyplot as plt
25
+ import numpy as np
26
+ from pathlib import Path
27
+
28
+
29
def load_history(filepath):
    """Read the JSON training history stored at *filepath* and return the parsed object."""
    with open(filepath, 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
33
+
34
def detect_num_classes(history):
    """Infer the number of classes from the first validation-metrics entry.

    Counts the ``class_<i>`` keys of the first epoch's ``dice`` dictionary
    (the aggregate ``mean`` key is ignored). Falls back to 3 when no
    validation metrics were recorded.
    """
    val_metrics = history['val_metrics']
    if not val_metrics:
        return 3

    dice_by_class = val_metrics[0]['dice']
    return sum(1 for key in dice_by_class.keys() if key.startswith('class_'))
42
+
43
def get_class_names(num_classes):
    """Map ``class_<i>`` keys to display names for the given class count.

    The 3-class and 4-class layouts have fixed anatomical names; any other
    count gets generic ``Class <i>`` labels.
    """
    known_layouts = {
        3: ('Background', 'Ventricles', 'Abnormal WMH'),
        4: ('Background', 'Ventricles', 'Normal WMH', 'Abnormal WMH'),
    }

    labels = known_layouts.get(num_classes)
    if labels is None:
        labels = tuple(f'Class {i}' for i in range(num_classes))

    return {f'class_{index}': label for index, label in enumerate(labels)}
60
+
61
def convert_to_native_types(obj):
    """Recursively convert numpy types to native Python types for JSON serialization.

    Handles numpy booleans, integers, floats and arrays, and descends into
    dicts (values only), lists and tuples. Any other object is returned
    unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        An equivalent structure in which numpy scalars/arrays are replaced by
        ``bool``/``int``/``float``/``list``. Tuples come back as plain tuples.
    """
    # np.bool_ is neither np.integer nor np.floating and json cannot encode it.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_to_native_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_to_native_types(item) for item in obj]
    if isinstance(obj, tuple):
        # Previously tuples were passed through untouched, leaving any numpy
        # scalars inside them unserializable.
        return tuple(convert_to_native_types(item) for item in obj)
    return obj
75
+
76
def find_best_epoch(history, num_classes):
    """
    Find the best epoch of a training run.

    Only epochs whose beta value exceeds 0.95 are eligible. If no epoch
    qualifies, a warning is printed and all epochs are considered. Among the
    eligible epochs the winner maximizes the composite score

        0.6 * dice(abnormal WMH) + 0.3 * dice(ventricles) + 0.1 * (1 - val_loss)

    where the validation loss enters unnormalized. The per-criterion best
    epochs (abnormal WMH dice, ventricles dice, validation loss) are reported
    alongside.

    NOTE: when ``history`` has no 'beta_value' entry, a list of ones (one per
    'val_loss' entry) is stored under that key on the passed-in dictionary.
    This side effect is kept from the original implementation.

    Args:
        history: Training history dictionary. Reads 'val_metrics', 'val_loss',
            'train_loss', 'wce_loss', 'ufd_loss' and, when present,
            'beta_value', 'val_loss_wce', 'val_loss_ufd'.
        num_classes: Number of classes; the abnormal WMH class is 'class_3'
            for 4 classes and 'class_2' otherwise.

    Returns:
        Tuple ``(best_epoch, analysis)`` with a 1-based epoch number and a
        JSON-serializable analysis dictionary, or ``(None, {})`` when there is
        nothing to analyse.
    """
    val_metrics = history['val_metrics']
    if not val_metrics:
        return None, {}

    beta_threshold = 0.95
    total_epochs = len(val_metrics)

    if 'beta_value' in history:
        beta_values = history['beta_value']
    else:
        beta_values = [1] * len(history.get('val_loss', []))
        history['beta_value'] = beta_values

    val_losses = history['val_loss']

    # Only epochs that have both metrics and a validation loss can be scored;
    # this guards against index errors when the lists differ in length.
    usable_epochs = min(total_epochs, len(val_losses))

    # Find epochs where beta > 0.95 (CRITICAL FILTER)
    valid_epoch_indices = [
        i for i, beta in enumerate(beta_values[:usable_epochs]) if beta > beta_threshold
    ]

    if not valid_epoch_indices:
        print("⚠️ WARNING: No epochs found with beta > 0.95!")
        print(" Using all epochs for analysis (not recommended).")
        valid_epoch_indices = list(range(usable_epochs))

    if not valid_epoch_indices:
        # No epoch carries a validation loss: nothing can be ranked.
        return None, {}

    first_valid_epoch = valid_epoch_indices[0] + 1

    # Determine the key for abnormal WMH
    abnormal_key = 'class_3' if num_classes == 4 else 'class_2'
    ventricles_key = 'class_1'

    # Extract metrics
    abnormal_dice = [m['dice'][abnormal_key] for m in val_metrics]
    ventricles_dice = [m['dice'][ventricles_key] for m in val_metrics]

    # Per-criterion winners, restricted to valid epochs (first winner on ties).
    best_abnormal_idx = max(valid_epoch_indices, key=lambda i: abnormal_dice[i])
    best_ventricles_idx = max(valid_epoch_indices, key=lambda i: ventricles_dice[i])
    best_val_loss_idx = min(valid_epoch_indices, key=lambda i: val_losses[i])

    # Composite score; invalid epochs stay at -inf so argmax never picks them.
    composite_scores = [float('-inf')] * total_epochs
    for i in valid_epoch_indices:
        composite_scores[i] = (
            0.6 * abnormal_dice[i]
            + 0.3 * ventricles_dice[i]
            + 0.1 * (1 - val_losses[i])
        )

    best_overall_idx = int(np.argmax(composite_scores))
    best_overall_epoch = best_overall_idx + 1

    # Get all metrics at best epoch
    best_epoch_metrics = val_metrics[best_overall_idx]

    analysis = {
        'best_overall_epoch': int(best_overall_epoch),
        'best_overall_epoch_idx': int(best_overall_idx),
        'best_abnormal_epoch': int(best_abnormal_idx + 1),
        'best_abnormal_dice': float(abnormal_dice[best_abnormal_idx]),
        'best_ventricles_epoch': int(best_ventricles_idx + 1),
        'best_ventricles_dice': float(ventricles_dice[best_ventricles_idx]),
        'best_val_loss_epoch': int(best_val_loss_idx + 1),
        'best_val_loss': float(val_losses[best_val_loss_idx]),
        'composite_score': float(composite_scores[best_overall_idx]),
        'abnormal_key': abnormal_key,
        'num_classes': int(num_classes),
        'first_valid_epoch': int(first_valid_epoch),
        'total_valid_epochs': int(len(valid_epoch_indices)),
        'beta_threshold': beta_threshold,
        'total_epochs': int(total_epochs),
        # Add complete metrics at best epoch
        'best_epoch_metrics': {
            'dice': best_epoch_metrics['dice'],
            'precision': best_epoch_metrics['precision'],
            'recall': best_epoch_metrics['recall'],
            'val_loss': float(val_losses[best_overall_idx]),
            'train_loss': float(history['train_loss'][best_overall_idx]),
            'wce_loss': float(history['wce_loss'][best_overall_idx]),
            'ufd_loss': float(history['ufd_loss'][best_overall_idx]),
            # Validation-side loss components are optional in the history.
            'val_loss_wce': float(history['val_loss_wce'][best_overall_idx]) if 'val_loss_wce' in history else None,
            'val_loss_ufd': float(history['val_loss_ufd'][best_overall_idx]) if 'val_loss_ufd' in history else None,
            'beta_value': float(beta_values[best_overall_idx])
        }
    }

    # Convert all numpy types to native Python types
    analysis = convert_to_native_types(analysis)

    return best_overall_epoch, analysis
189
+
190
def save_analysis_json(analysis, output_path):
    """Write the analysis dictionary to *output_path* as indented JSON and report the path."""
    serializable = convert_to_native_types(analysis)
    with open(output_path, 'w') as out_file:
        out_file.write(json.dumps(serializable, indent=2))
    print(f"✓ Analysis saved to: {output_path}")
196
+
197
def save_enhanced_history(history, analysis, output_path):
    """Write a copy of *history* extended with a 'best_epoch_analysis' entry.

    The passed-in history is shallow-copied, so no key is added to it.
    """
    enriched = history.copy()
    enriched['best_epoch_analysis'] = convert_to_native_types(analysis)
    serializable = convert_to_native_types(enriched)

    with open(output_path, 'w') as out_file:
        out_file.write(json.dumps(serializable, indent=2))
    print(f"✓ Enhanced history saved to: {output_path}")
206
+
207
def create_training_summary(history, analysis, class_names):
    """Assemble a JSON-ready summary of a training run.

    Combines the run configuration, the best-epoch selection, the
    per-criterion best values, the final-epoch metrics and the epoch-by-epoch
    progression of the key metrics into one dictionary.
    """
    val_metrics = history['val_metrics']
    last_metrics = val_metrics[-1]
    val_loss = history['val_loss']
    train_loss = history['train_loss']

    training_config = {
        'total_epochs': analysis['total_epochs'],
        'num_classes': analysis['num_classes'],
        'class_names': class_names,
        'model_type': 'a U-Net'
    }

    best_epoch_selection = {
        'overall_best_epoch': analysis['best_overall_epoch'],
        'composite_score': analysis['composite_score'],
        'selection_criteria': {
            'abnormal_wmh_weight': 0.6,
            'ventricles_weight': 0.3,
            'val_loss_weight': 0.1
        }
    }

    priority_metrics = {
        'abnormal_wmh': {
            'best_epoch': analysis['best_abnormal_epoch'],
            'best_dice': analysis['best_abnormal_dice']
        },
        'ventricles': {
            'best_epoch': analysis['best_ventricles_epoch'],
            'best_dice': analysis['best_ventricles_dice']
        },
        'validation_loss': {
            'best_epoch': analysis['best_val_loss_epoch'],
            'best_loss': analysis['best_val_loss']
        }
    }

    training_progression = {
        'final_epoch_metrics': {
            'dice': last_metrics['dice'],
            'precision': last_metrics['precision'],
            'recall': last_metrics['recall'],
            'val_loss': val_loss[-1],
            'train_loss': train_loss[-1]
        },
        'convergence_info': {
            'epochs_trained': len(val_loss)
        }
    }

    # Epoch-by-epoch curves for the classes that drive model selection.
    abnormal_key = analysis['abnormal_key']
    epoch_progression = {
        'abnormal_wmh_dice': [entry['dice'][abnormal_key] for entry in val_metrics],
        'ventricles_dice': [entry['dice']['class_1'] for entry in val_metrics],
        'mean_dice': [entry['dice']['mean'] for entry in val_metrics],
        'val_loss': val_loss,
        'train_loss': train_loss
    }

    summary = {
        'training_config': training_config,
        'best_epoch_selection': best_epoch_selection,
        'priority_metrics': priority_metrics,
        'best_epoch_metrics': analysis['best_epoch_metrics'],
        'training_progression': training_progression,
        'epoch_progression': epoch_progression,
    }

    return convert_to_native_types(summary)
266
+
267
def _mark_best_epoch(ax, best_epoch, labelled=False):
    """Draw the red dashed best-epoch marker on *ax* (no-op when there is no best epoch)."""
    if not best_epoch:
        return
    line_kwargs = dict(x=best_epoch, color='red', linestyle='--', linewidth=2, alpha=0.7)
    if labelled:
        line_kwargs['label'] = f'Best Epoch ({best_epoch})'
    ax.axvline(**line_kwargs)


def _plot_per_class_metric(ax, epochs, history, metric, class_names, num_classes,
                           colors, marker, markersize, ylabel, title):
    """Plot one validation metric per foreground class (class_0/background is skipped)."""
    for i in range(1, num_classes):
        class_key = f'class_{i}'
        scores = [m[metric][class_key] for m in history['val_metrics']]
        ax.plot(epochs, scores, marker, linewidth=2, markersize=markersize,
                label=class_names[class_key], color=colors[i % len(colors)])
    ax.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=11, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 1])


def plot_training_history(history, save_path='training_history.png'):
    """Create comprehensive visualization of training history.

    Draws a 2x3 grid (loss, per-class dice/precision/recall, mean metrics and
    a text summary). When the history contains 'val_loss_wce' and
    'val_loss_ufd', a third row is added with the WCE/UFD loss components and
    their beta-weighted contributions. The figure is saved to *save_path* at
    300 dpi and then closed.

    Args:
        history: Training history dictionary as loaded from history.json.
        save_path: Destination of the rendered figure.

    Returns:
        The analysis dictionary produced by ``find_best_epoch`` (empty when
        no validation metrics are available).
    """
    num_classes = detect_num_classes(history)
    class_names = get_class_names(num_classes)
    best_epoch, analysis = find_best_epoch(history, num_classes)

    epochs = range(1, len(history['train_loss']) + 1)

    # Detect whether new-style history (with val_loss_wce / val_loss_ufd) is present
    has_val_components = 'val_loss_wce' in history and 'val_loss_ufd' in history

    # Create figure — 3 rows × 3 cols when val components exist, else 2×3
    nrows = 3 if has_val_components else 2
    fig = plt.figure(figsize=(18, nrows * 5))
    gs = fig.add_gridspec(nrows, 3, hspace=0.35, wspace=0.3)

    # Color scheme
    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D']
    wce_color = '#4CAF50'  # green – WCE
    ufd_color = '#9C27B0'  # purple – UFD

    # 1. Training and Validation Loss (combined / weighted)
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(epochs, history['train_loss'], 'o-', linewidth=2, markersize=6,
             color=colors[0], label='Train Loss')
    ax1.plot(epochs, history['val_loss'], 's-', linewidth=2, markersize=6,
             color=colors[2], label='Val Loss')
    _mark_best_epoch(ax1, best_epoch, labelled=True)
    ax1.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax1.set_ylabel('Loss', fontsize=11, fontweight='bold')
    ax1.set_title('Training & Validation Loss\n(Combined Adaptive Loss)', fontsize=13, fontweight='bold')
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3)

    # 2. Dice Scores (excluding background)
    ax2 = fig.add_subplot(gs[0, 1])
    _plot_per_class_metric(ax2, epochs, history, 'dice', class_names, num_classes,
                           colors, 'o-', 6, 'Dice Score', 'Dice Scores by Class')
    _mark_best_epoch(ax2, best_epoch, labelled=True)
    ax2.legend(fontsize=9)

    # 3. Precision Scores (excluding background)
    ax3 = fig.add_subplot(gs[0, 2])
    _plot_per_class_metric(ax3, epochs, history, 'precision', class_names, num_classes,
                           colors, 's-', 5, 'Precision', 'Precision by Class')
    _mark_best_epoch(ax3, best_epoch)
    ax3.legend(fontsize=9)

    # 4. Recall Scores (excluding background)
    ax4 = fig.add_subplot(gs[1, 0])
    _plot_per_class_metric(ax4, epochs, history, 'recall', class_names, num_classes,
                           colors, '^-', 5, 'Recall', 'Recall by Class')
    _mark_best_epoch(ax4, best_epoch)
    ax4.legend(fontsize=9)

    # 5. Mean Metrics
    ax5 = fig.add_subplot(gs[1, 1])
    mean_dice = [m['dice']['mean'] for m in history['val_metrics']]
    mean_precision = [m['precision']['mean'] for m in history['val_metrics']]
    mean_recall = [m['recall']['mean'] for m in history['val_metrics']]

    ax5.plot(epochs, mean_dice, 'o-', linewidth=2, markersize=6,
             color=colors[0], label='Mean Dice')
    ax5.plot(epochs, mean_precision, 's-', linewidth=2, markersize=5,
             color=colors[1], label='Mean Precision')
    ax5.plot(epochs, mean_recall, '^-', linewidth=2, markersize=5,
             color=colors[2], label='Mean Recall')
    _mark_best_epoch(ax5, best_epoch)
    ax5.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax5.set_ylabel('Score', fontsize=11, fontweight='bold')
    ax5.set_title('Mean Validation Metrics', fontsize=13, fontweight='bold')
    ax5.legend(fontsize=9)
    ax5.grid(True, alpha=0.3)
    ax5.set_ylim([0, 1])

    # ── Row 3 plots (only when val components are available) ──────────────
    if has_val_components:
        # 7. Loss components (WCE vs UFD), train solid and val dashed.
        # The curves are plotted unscaled, so the legend no longer claims "x10".
        ax7 = fig.add_subplot(gs[2, 0])
        ax7.plot(epochs, history['wce_loss'], 'o-', linewidth=2, markersize=5,
                 color=wce_color, label='Train WCE Loss')
        ax7.plot(epochs, history['ufd_loss'], 's-', linewidth=2, markersize=5,
                 color=ufd_color, label='Train UFD Loss')
        ax7.plot(epochs, history['val_loss_wce'], 'o--', linewidth=1.5, markersize=4,
                 color=wce_color, alpha=0.6, label='Val WCE Loss')
        ax7.plot(epochs, history['val_loss_ufd'], 's--', linewidth=1.5, markersize=4,
                 color=ufd_color, alpha=0.6, label='Val UFD Loss')
        _mark_best_epoch(ax7, best_epoch, labelled=True)
        ax7.set_xlabel('Epoch', fontsize=11, fontweight='bold')
        ax7.set_ylabel('Loss', fontsize=11, fontweight='bold')
        ax7.set_title('Loss Components: WCE vs UFD\n(Train solid · Val dashed)', fontsize=13, fontweight='bold')
        ax7.legend(fontsize=8)
        ax7.grid(True, alpha=0.3)

        # 8. Weighted contribution of each loss to the total loss
        ax8 = fig.add_subplot(gs[2, 1])
        # NOTE: find_best_epoch() above stores a 'beta_value' list on the
        # history when it is missing, so this linear default is only a fallback.
        beta_values = history.get('beta_value', [e / len(epochs) for e in epochs])
        betas = np.array(beta_values)

        # Weighted contributions
        train_wce_contrib = (1 - betas) * np.array(history['wce_loss'])
        train_ufd_contrib = betas * np.array(history['ufd_loss'])
        val_wce_contrib = (1 - betas) * np.array(history['val_loss_wce'])
        val_ufd_contrib = betas * np.array(history['val_loss_ufd'])

        ax8.stackplot(list(epochs),
                      train_wce_contrib, train_ufd_contrib,
                      labels=['(1−β)·WCE [train]', 'β·UFD [train]'],
                      colors=[wce_color, ufd_color], alpha=0.55)
        ax8.plot(epochs, history['train_loss'], 'k-', linewidth=1.5, label='Total Train Loss')

        # Overlay val contributions as lines for clarity
        ax8.plot(epochs, val_wce_contrib, '--', color=wce_color, linewidth=1.5,
                 alpha=0.8, label='(1−β)·WCE [val]')
        ax8.plot(epochs, val_ufd_contrib, '--', color=ufd_color, linewidth=1.5,
                 alpha=0.8, label='β·UFD [val]')
        _mark_best_epoch(ax8, best_epoch)
        ax8.set_xlabel('Epoch', fontsize=11, fontweight='bold')
        ax8.set_ylabel('Weighted Loss', fontsize=11, fontweight='bold')
        ax8.set_title('Weighted Loss Contributions\n(Adaptive β Schedule)', fontsize=13, fontweight='bold')
        ax8.legend(fontsize=8)
        ax8.grid(True, alpha=0.3)

    # 6. Analysis Summary (text panel)
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.axis('off')

    if analysis:
        abnormal_class = class_names[analysis['abnormal_key']]
        best_epoch_idx = analysis['best_overall_epoch'] - 1
        best_losses = analysis['best_epoch_metrics']

        # Get dice scores for all classes at the best epoch
        best_epoch_dice = history['val_metrics'][best_epoch_idx]['dice']

        # Build dice scores text (excluding background)
        dice_scores_text = ""
        for i in range(1, num_classes):
            class_key = f'class_{i}'
            dice_scores_text += f" {class_names[class_key]}: {best_epoch_dice[class_key]:.4f}\n"

        thin_rule = '─' * 40
        summary_lines = [
            "",
            "TRAINING ANALYSIS SUMMARY",
            '=' * 40,
            "",
            "Model: a U-Net",
            f"Number of Classes: {analysis['num_classes']}",
            f"Total Epochs: {len(epochs)}",
            "",
            f"BEST OVERALL EPOCH: {analysis['best_overall_epoch']}",
            f"(Composite Score: {analysis['composite_score']:.4f})",
            "",
            "Dice Scores at Best Epoch:",
            dice_scores_text,
            thin_rule,
            "Priority Metrics:",
            thin_rule,
            "",
            f"Best {abnormal_class} Dice:",
            f"Epoch {analysis['best_abnormal_epoch']}: {analysis['best_abnormal_dice']:.4f}",
            "",
            "Best Ventricles Dice:",
            f"Epoch {analysis['best_ventricles_epoch']}: {analysis['best_ventricles_dice']:.4f}",
            "",
            "Best Validation Loss:",
            f"Epoch {analysis['best_val_loss_epoch']}: {analysis['best_val_loss']:.4f}",
            "",
            thin_rule,
            "Loss at Best Epoch:",
            f"Train WCE: {best_losses['wce_loss']:.4f}",
            f"Train UFD: {best_losses['ufd_loss']:.4f}",
        ]

        # Validation-side components only exist in new-style histories.
        if best_losses.get('val_loss_wce') is not None:
            summary_lines += [
                f"Val WCE: {best_losses['val_loss_wce']:.4f}",
                f"Val UFD: {best_losses['val_loss_ufd']:.4f}",
            ]

        summary_lines += [
            f"β value: {best_losses['beta_value']:.4f}",
            "",
            thin_rule,
            "Scoring Weights:",
            f"{abnormal_class}: 60%",
            "Ventricles: 30%",
            "Val Loss: 10%",
            "",
        ]
        summary_text = "\n".join(summary_lines)

        ax6.text(0.05, 0.95, summary_text, transform=ax6.transAxes,
                 fontsize=9, verticalalignment='top', fontfamily='monospace',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

    fig.suptitle('a U-Net Training History - Comprehensive Analysis\n'
                 '(Adaptive Loss: WCE + UFD with β schedule)',
                 fontsize=16, fontweight='bold', y=0.998)

    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Visualization saved to: {save_path}")

    # Release the figure: pyplot keeps figures alive until closed, which
    # accumulates memory when this function is called once per fold/variant.
    plt.close(fig)

    return analysis
525
+
526
+ def print_detailed_analysis(analysis):
527
+ """Print detailed analysis to console."""
528
+ if not analysis:
529
+ print("No analysis available.")
530
+ return
531
+
532
+ print("\n" + "="*60)
533
+ print("DETAILED TRAINING ANALYSIS - a U-NET")
534
+ print("="*60)
535
+ print(f"\n📊 Number of Classes: {analysis['num_classes']}")
536
+ print(f"\n🏆 RECOMMENDED EPOCH: {analysis['best_overall_epoch']}")
537
+ print(f" Composite Score: {analysis['composite_score']:.4f}")
538
+ print("\n" + "-"*60)
539
+ print("Individual Best Performances:")
540
+ print("-"*60)
541
+ print(f"\n🎯 Abnormal WMH Dice (TOP PRIORITY):")
542
+ print(f" Best Epoch: {analysis['best_abnormal_epoch']}")
543
+ print(f" Best Score: {analysis['best_abnormal_dice']:.4f}")
544
+ print(f"\n🫀 Ventricles Dice (SECONDARY):")
545
+ print(f" Best Epoch: {analysis['best_ventricles_epoch']}")
546
+ print(f" Best Score: {analysis['best_ventricles_dice']:.4f}")
547
+ print(f"\n📉 Validation Loss (TERTIARY):")
548
+ print(f" Best Epoch: {analysis['best_val_loss_epoch']}")
549
+ print(f" Lowest Loss: {analysis['best_val_loss']:.4f}")
550
+ print("\n" + "="*60)
551
+ print("\nNote: Best overall epoch is calculated using weighted scoring:")
552
+ print(" • Abnormal WMH Dice: 60%")
553
+ print(" • Ventricles Dice: 30%")
554
+ print(" • Validation Loss: 10%")
555
+ print("="*60 + "\n")
556
+
557
def main_viz(filepath='history_sample.json', save_outputs=True):
    """Analyse and visualize one training history file.

    Loads the history JSON, renders the training-history figure next to it,
    prints the best-epoch analysis and, when requested, writes three JSON
    artefacts into the same directory.

    Args:
        filepath: Path to the history JSON file. All outputs are written to
            the directory that contains it.
        save_outputs: When true, also write 'best_epoch_analysis.json',
            'history_with_analysis.json' and 'training_summary.json'.

    Returns:
        Tuple ``(analysis, history)``.
    """
    # Load history
    print(f"Loading training history from: {filepath}")
    history = load_history(filepath)

    print(f"✓ Loaded {len(history['train_loss'])} epochs of training data")

    # Get output directory
    out_dir = os.path.dirname(filepath)

    # Detect number of classes and get class names
    num_classes = detect_num_classes(history)
    class_names = get_class_names(num_classes)

    # Create visualization. plot_training_history() runs the best-epoch
    # search itself and returns the analysis, so it is not computed twice.
    plot_name = 'a_unet_training_analysis.png'
    analysis = plot_training_history(history, save_path=os.path.join(out_dir, plot_name))

    # Print detailed analysis
    print_detailed_analysis(analysis)

    if save_outputs:
        print("\n" + "="*60)
        print("SAVING ANALYSIS OUTPUTS")
        print("="*60)

        # 1. Save standalone analysis JSON
        analysis_path = os.path.join(out_dir, 'best_epoch_analysis.json')
        save_analysis_json(analysis, analysis_path)

        # 2. Save enhanced history with analysis appended
        enhanced_history_path = os.path.join(out_dir, 'history_with_analysis.json')
        save_enhanced_history(history, analysis, enhanced_history_path)

        # 3. Save training summary
        summary = create_training_summary(history, analysis, class_names)
        summary_path = os.path.join(out_dir, 'training_summary.json')
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)
        print(f"✓ Training summary saved to: {summary_path}")

        print("\n" + "="*60)
        print("ALL OUTPUTS SAVED SUCCESSFULLY")
        print("="*60)
        print("\nGenerated files:")
        # Fixed: the listing used to name 'unet_training_analysis.png', which
        # is not the file that is actually written.
        print(f" 1. {plot_name} - Visualization")
        print(f" 2. best_epoch_analysis.json - Best epoch analysis")
        print(f" 3. history_with_analysis.json - Enhanced history")
        print(f" 4. training_summary.json - Comprehensive training summary")
        print("="*60 + "\n")

    return analysis, history
612
+
613
if __name__ == "__main__":

    # Batch driver: analyse the history.json of every selected fold/variant
    # combination. Out of folds 0-4 and variants 0-4, folds 0, 2, 3, 4 are
    # skipped and only variant 1 is kept.
    skipped_folds = (0, 2, 3, 4)
    kept_variants = (1,)

    selected_folds = [fold for fold in range(5) if fold not in skipped_folds]
    selected_variants = [variant for variant in range(5) if variant in kept_variants]

    scenario = 'standard_3class'

    for fold in selected_folds:
        fold_num = f'fold_{fold}'

        for variant in selected_variants:
            experiment_dir = f'/mnt/e/MBashiri/ours_articles/Paper#4/Development/results_fold_{fold}_var_{variant}_zscore2/models'
            filepath = os.path.join(experiment_dir, scenario, fold_num, 'history.json')

            main_viz(filepath=filepath)
640
+
models/for_WMH_Vent/model_training_scripts/p4_variant_all_net.py ADDED
@@ -0,0 +1,1051 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 - All U-Net models with Adaptive Loss (WCE + UFL)
3
+
4
+ WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
5
+ Three-class segmentation: Background vs Ventricles vs Abnormal WMH
6
+ Professional results saving and visualization for publication
7
+
8
+ This relates to our article:
9
+ "Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
10
+ A Large-Scale Multiple Sclerosis Study"
11
+
12
+ Features:
13
+ - Various U-Net architectures
14
+ - Weighted Categorical Cross-Entropy loss
15
+ - Unified Focal loss
16
+ - One-hot encoded targets
17
+ - Class weight computation per fold
18
+
19
+ Authors:
20
+ "Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
21
+
22
+ Developer:
23
+ "Mahdi Bashiri Bawil"
24
+ """
25
+
26
+ import tensorflow as tf
27
+ import os
28
+ import time
29
+ import numpy as np
30
+ import matplotlib.pyplot as plt
31
+ from pathlib import Path
32
+ from tqdm import tqdm
33
+ import json
34
+
35
+ # Import data loader
36
+ from p4_data_loader import DataConfig, P2DataLoader
37
+
38
+ # Import utilities from baseline
39
+ from utility_functions import (
40
+ clear_gpu_memory,
41
+ get_gpu_memory_info,
42
+ )
43
+
44
+ # Import class weights utility
45
+ from p4_compute_class_weights import compute_and_save_class_weights, load_class_weights
46
+
47
print("TensorFlow Version:", tf.__version__)

###################### GPU Configuration ######################

# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
# `set_memory_growth` raises RuntimeError in some situations; that error is
# reported and otherwise ignored, so the script keeps running either way.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected - training will be slow")
else:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
    else:
        # Only reached when memory growth was enabled on all devices.
        print("✅ GPU memory growth enabled")
        print(f" Available GPUs: {len(physical_devices)}")
63
+
64
+ """
65
+ GPU Memory Management for Sequential Experiments
66
+ To properly release memory between experiments
67
+ """
68
+
69
+ ###################### Target Preparation ######################
70
+
71
def prepare_inputs(paired_input, target_mask, num_classes):
    """
    Split the network input out of a paired batch and one-hot encode labels.

    Args:
        paired_input: 4-D batch whose axis 2 holds the FLAIR image followed by
            the mask; documented by the authors as (bs, 256, 512, 1).
        target_mask: Class-label map, documented as (bs, 256, 256) with values
            in [0, num_classes-1]. Must have a dtype accepted by `tf.one_hot`.
        num_classes: Depth of the one-hot encoding.

    Returns:
        (flair, target_onehot):
            flair         - the first 256 positions along axis 2 of
                            `paired_input`; no rescaling happens here, the
                            data is expected to be normalized upstream.
            target_onehot - float32 one-hot tensor with a trailing axis of
                            size `num_classes`.
    """
    # Labels -> one-hot along a new last axis
    onehot_target = tf.one_hot(target_mask, depth=num_classes, dtype=tf.float32)

    # Left half of the paired image is the FLAIR channel
    flair = paired_input[:, :, :256, :]

    return flair, onehot_target
91
+
92
+ ###################### Metrics Calculation ######################
93
+
94
def compute_classwise_metrics(all_val_true, all_val_pred, num_classes, exclude_class=None):
    """
    Per-class Dice, precision and recall over a whole validation pass.

    Predictions are hardened with argmax before counting, and all pixels of
    all batches are pooled into a single set of counts per class.

    Args:
        all_val_true: List of one-hot ground-truth tensors (one per batch).
        all_val_pred: List of model probability tensors (one per batch).
        num_classes: Size of the last axis of both inputs.
        exclude_class: Optional class index left out of the per-class entries
            and therefore out of the means.

    Returns:
        Dict with keys 'dice', 'precision', 'recall'; each maps
        'class_<idx>' to a float plus 'mean' to the average of the evaluated
        classes.
    """
    eps = 1e-7

    # Pool every batch and collapse all leading axes: (N*H*W, num_classes)
    truth = tf.reshape(tf.concat(all_val_true, axis=0), [-1, num_classes]).numpy()
    probs = tf.reshape(tf.concat(all_val_pred, axis=0), [-1, num_classes])

    # Hard predictions, re-encoded as one-hot so columns align with `truth`
    hard_pred = tf.one_hot(tf.argmax(probs, axis=-1), depth=num_classes).numpy()

    metrics = {'dice': {}, 'precision': {}, 'recall': {}}

    for idx in range(num_classes):
        if idx == exclude_class:
            continue

        gt_col = truth[:, idx]
        pr_col = hard_pred[:, idx]
        gt_hit = gt_col == 1
        pr_hit = pr_col == 1

        # Confusion counts for this class treated as a binary problem
        tp = np.sum(gt_hit & pr_hit)
        fp = np.sum((gt_col == 0) & pr_hit)
        fn = np.sum(gt_hit & (pr_col == 0))

        key = f'class_{idx}'
        # eps keeps the ratios finite when a class is absent
        metrics['dice'][key] = float((2 * tp) / (2 * tp + fp + fn + eps))
        metrics['precision'][key] = float(tp / (tp + fp + eps))
        metrics['recall'][key] = float(tp / (tp + fn + eps))

    # Average over the evaluated classes only ('mean' is added afterwards,
    # so it never feeds into itself)
    for table in metrics.values():
        table['mean'] = np.mean(list(table.values()))

    return metrics
160
+
161
+ ###################### Experiment Configuration ######################
162
+
163
class ExperimentConfig:
    """
    Configuration for a specific U-Net experiment.

    Holds the run identification, training hyperparameters and output paths.
    Constructing an instance has side effects: it creates the models,
    figures, logs, checkpoint and class-weights directories relative to the
    current working directory and writes `config.json` into the checkpoint
    directory.

    Raises:
        ValueError: if `class_scenario` is not '3class' or '4class'
            (previously any other value silently meant 4 classes).
    """

    # Supported label layouts and the number of classes each one implies.
    _NUM_CLASSES = {'3class': 3, '4class': 4}

    def __init__(self,
                 variant: int = 1,
                 preprocessing: str = 'standard',
                 class_scenario: str = '3class',
                 fold_id: int = 0,
                 architecture_name: str = 'unet'
                 ):
        # Fail early, before any directory is created, on an unknown scenario.
        if class_scenario not in self._NUM_CLASSES:
            raise ValueError(
                f"Unknown class_scenario {class_scenario!r}; "
                f"expected one of {sorted(self._NUM_CLASSES)}"
            )

        # Experiment identification
        self.variant = variant
        self.preprocessing = preprocessing    # 'standard' or 'zoomed'
        self.class_scenario = class_scenario  # '3class' or '4class'
        self.fold_id = fold_id
        self.architecture_name = architecture_name

        # Experiment name (note: does not include the variant number)
        self.exp_name = f"exp_{architecture_name}_{preprocessing}_{class_scenario}_fold{fold_id}"

        # Number of classes
        self.num_classes = self._NUM_CLASSES[class_scenario]

        # Training hyperparameters
        self.batch_size = 4
        self.img_width = 256
        self.img_height = 256
        self.epochs = 60

        # Optimizer parameters
        self.learning_rate = 2e-4
        self.beta_1 = 0.9

        # Adaptive loss parameters
        self.focal_gamma = 0.5        # Focal loss focusing parameter
        self.beta_threshold = 0.25    # Transition centred at epoch 15 of 60
        self.beta_smoothness = 0.02   # Transition width
        self.use_focal_alpha = True   # Use class weights in focal loss

        # ReduceLROnPlateau parameters
        self.lr_patience = 5            # Wait 5 epochs before reducing
        self.lr_reduction_factor = 0.5  # Reduce LR by half
        self.lr_min = 1e-7              # Don't go below this
        self.lr_monitor = 'val_loss'    # Or 'val_dice_mean'

        # Paths
        scenario_tag = f"{preprocessing}_{class_scenario}"
        fold_tag = f"fold_{fold_id}"
        self.results_dir = Path(f"results_fold_{fold_id}_var_{variant}_zscore3")
        self.models_dir = self.results_dir / "models" / scenario_tag
        self.figures_dir = self.results_dir / "figures" / scenario_tag / fold_tag
        self.logs_dir = self.results_dir / "logs" / scenario_tag / fold_tag

        # Checkpoint configuration
        self.checkpoint_dir = self.models_dir / fold_tag

        # Class weights directory
        self.weights_dir = Path("class_weights")

        # Create directories (parents=True so each call is self-sufficient)
        for directory in (self.models_dir, self.figures_dir, self.logs_dir,
                          self.checkpoint_dir, self.weights_dir):
            directory.mkdir(parents=True, exist_ok=True)

        # Save configuration
        self.save_config()

    def save_config(self):
        """Write the experiment configuration to `<checkpoint_dir>/config.json`."""
        config_dict = {
            'variant': self.variant,
            'variant_name': f'{self.architecture_name}',
            'preprocessing': self.preprocessing,
            'class_scenario': self.class_scenario,
            'fold_id': self.fold_id,
            'num_classes': self.num_classes,
            'batch_size': self.batch_size,
            'epochs': self.epochs,
            'focal_gamma': self.focal_gamma,
            'beta_threshold': self.beta_threshold,
            'beta_smoothness': self.beta_smoothness,
            'learning_rate': self.learning_rate,
            'beta_1': self.beta_1,
            'loss': 'Phase-transitioning segmentation loss (WCE → UFD)',
            # Settings that were previously missing from the saved record
            'use_focal_alpha': self.use_focal_alpha,
            'img_width': self.img_width,
            'img_height': self.img_height,
            'lr_patience': self.lr_patience,
            'lr_reduction_factor': self.lr_reduction_factor,
            'lr_min': self.lr_min,
            'lr_monitor': self.lr_monitor,
        }

        config_file = self.checkpoint_dir / "config.json"
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2)
253
+
254
+
255
+ ###################### Beta Scheduling ######################
256
+
257
def smooth_step(x, threshold=0.5, smoothness=0.1):
    """
    Sigmoid-shaped soft step used for the loss phase transition.

    Computes sigmoid((x - threshold) / smoothness).

    Args:
        x: Current progress (typically epoch / total_epochs)
        threshold: Value of x at which the output is exactly 0.5
        smoothness: Transition width (smaller = sharper, larger = smoother);
            must be non-zero

    Returns:
        Value in (0, 1):
            - x well below threshold: close to 0
            - x == threshold: 0.5
            - x well above threshold: close to 1

    Example (threshold=0.5, smoothness=0.1):
        smooth_step(0.3) = sigmoid(-2) ≈ 0.12   (mostly phase 1)
        smooth_step(0.5) = sigmoid(0)  = 0.50   (equal mix)
        smooth_step(0.7) = sigmoid(+2) ≈ 0.88   (mostly phase 2)
    """
    # Distance from the transition centre, measured in units of `smoothness`
    scaled_offset = (x - threshold) / smoothness
    return tf.sigmoid(scaled_offset)
290
+
291
+
292
def compute_beta_schedule(current_epoch, total_epochs,
                          threshold=0.5, smoothness=0.1):
    """
    Blend factor (beta) for the given epoch.

    Args:
        current_epoch: Current epoch number (0-indexed)
        total_epochs: Total number of epochs
        threshold: Fraction of training at which beta equals 0.5
        smoothness: Transition width passed to `smooth_step`

    Returns:
        Scalar float32 tensor in (0, 1)
    """
    # Fraction of training completed, as float32
    elapsed = tf.cast(current_epoch, tf.float32)
    horizon = tf.cast(total_epochs, tf.float32)
    return smooth_step(elapsed / horizon, threshold, smoothness)
309
+
310
+ ###################### Loss Functions ######################
311
+
312
def unified_focal_loss(y_true, y_pred, gamma=2.0, alpha=None, exclude_class=None):
    """
    Pixel-wise focal cross-entropy, optionally class-weighted and masked.

    Per pixel: FL = alpha_c * (1 - p_t)^gamma * (-log p_t), where p_t is the
    predicted probability of the true class and alpha_c the weight of the
    true class (omitted when `alpha` is None).

    Args:
        y_true: Ground truth, one-hot (bs, H, W, num_classes)
        y_pred: Predicted probabilities (bs, H, W, num_classes)
        gamma: Focusing parameter (default 2.0)
            - gamma=0: plain cross-entropy
            - larger gamma: easy pixels (p_t near 1) contribute less
        alpha: Optional per-class weights (num_classes,)
        exclude_class: Optional class index; pixels whose true class equals
            it contribute nothing and are left out of the normalisation

    Returns:
        Scalar loss: the mean over all pixels, or, when `exclude_class` is
        given, the sum over kept pixels divided by the number of kept pixels.
    """
    # Keep probabilities away from 0 and 1 so the log stays finite
    probs = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # Probability assigned to the true class, shape (bs, H, W)
    true_prob = tf.reduce_sum(y_true * probs, axis=-1)

    # Modulating factor times cross-entropy, shape (bs, H, W)
    per_pixel = tf.pow(1.0 - true_prob, gamma) * (-tf.math.log(true_prob))

    if alpha is not None:
        # Look up the weight of each pixel's true class via the one-hot mask
        class_alpha = tf.reshape(tf.cast(alpha, dtype=tf.float32), [1, 1, 1, -1])
        per_pixel = tf.reduce_sum(y_true * class_alpha, axis=-1) * per_pixel

    if exclude_class is None:
        return tf.reduce_mean(per_pixel)

    # 1.0 where the pixel's true class is kept, 0.0 where it is excluded
    keep = tf.cast(tf.argmax(y_true, axis=-1) != exclude_class, tf.float32)
    return tf.reduce_sum(per_pixel * keep) / (tf.reduce_sum(keep) + 1e-7)
387
+
388
+
389
def unified_focal_dice_loss(y_true, y_pred, gamma=0.5, delta=0.6, alpha=None, exclude_class=None):
    """
    Dice-based focal loss averaged over classes.

    For every evaluated class, soft TP/FP/FN are accumulated over the whole
    batch and combined as

        loss_c = (1 - Dice_c)^gamma * (1 - precision_c * recall_c)^delta

    optionally multiplied by alpha[c]; the result is the mean of loss_c.

    Args:
        y_true: Ground truth one-hot (bs, H, W, num_classes)
        y_pred: Predicted probabilities (bs, H, W, num_classes)
        gamma: Exponent on (1 - Dice) (default 0.5); gamma=0 removes that
            factor
        delta: Exponent on (1 - precision * recall) (default 0.6)
        alpha: Optional per-class weights, indexable by class index
        exclude_class: Optional class index to leave out of the loss

    Returns:
        Scalar loss value

    Raises:
        ValueError: if the class axis of `y_pred` has no static size, or if
            no class remains after applying `exclude_class`.
    """
    smooth = 1e-6
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # The Python loop below needs a static class count. (The former
    # `isinstance(tf.shape(...)[-1], int)` test could never be true, so the
    # static shape was always what got used.)
    num_classes = y_pred.shape[-1]
    if num_classes is None:
        raise ValueError("unified_focal_dice_loss requires a static class dimension on y_pred")

    class_indices = [c for c in range(num_classes) if c != exclude_class]
    if not class_indices:
        raise ValueError("unified_focal_dice_loss: no classes left to evaluate")

    unified_losses = []
    for class_idx in class_indices:
        # Flatten this class's channel over batch and space
        y_true_f = tf.reshape(y_true[..., class_idx], [-1])
        y_pred_f = tf.reshape(y_pred[..., class_idx], [-1])

        # Soft true positives, false positives, false negatives
        tp = tf.reduce_sum(y_true_f * y_pred_f)
        fp = tf.reduce_sum((1.0 - y_true_f) * y_pred_f)
        fn = tf.reduce_sum(y_true_f * (1.0 - y_pred_f))

        # `smooth` keeps every ratio defined when a class is absent
        precision = (tp + smooth) / (tp + fp + smooth)
        recall = (tp + smooth) / (tp + fn + smooth)
        dice = (2.0 * tp + smooth) / (2.0 * tp + fp + fn + smooth)

        # Low Dice and low precision*recall both increase the loss
        unified_loss_class = tf.pow(1.0 - dice, gamma) * tf.pow(1.0 - precision * recall, delta)

        # Apply class weights
        if alpha is not None:
            unified_loss_class = unified_loss_class * tf.cast(alpha[class_idx], tf.float32)

        unified_losses.append(unified_loss_class)

    # Mean across the evaluated classes
    return tf.reduce_mean(tf.stack(unified_losses))
459
+
460
+
461
def weighted_categorical_crossentropy(y_true, y_pred, class_weights, exclude_class=None):
    """
    Class-weighted categorical cross-entropy over pixels.

    Args:
        y_true: one-hot ground truth, documented as (bs, 256, 256, num_classes)
        y_pred: probabilities of the same shape
        class_weights: (num_classes,) weight per class
        exclude_class: Optional class index; pixels whose true class equals
            it are masked out of both the sum and the normalisation

    Returns:
        Scalar loss: mean of the weighted per-pixel CE, or, with
        `exclude_class`, the masked sum divided by the number of kept pixels.
    """
    # Avoid log(0)
    probs = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # (num_classes,) -> (1, 1, 1, num_classes) so it broadcasts over pixels
    class_w = tf.reshape(tf.cast(class_weights, dtype=tf.float32), [1, 1, 1, -1])

    # Per-pixel cross-entropy and per-pixel weight of the true class
    pixel_ce = -tf.reduce_sum(y_true * tf.math.log(probs), axis=-1)
    pixel_w = tf.reduce_sum(y_true * class_w, axis=-1)

    if exclude_class is None:
        return tf.reduce_mean(pixel_ce * pixel_w)

    # 1.0 for kept pixels, 0.0 for pixels of the excluded class
    keep = tf.cast(tf.argmax(y_true, axis=-1) != exclude_class, tf.float32)
    return tf.reduce_sum(pixel_ce * pixel_w * keep) / (tf.reduce_sum(keep) + 1e-7)
498
+
499
+
500
def adaptive_segmentation_loss(y_true, y_pred, class_weights, beta,
                               focal_gamma=0.5, use_focal_alpha=True,
                               exclude_class=None):
    """
    Blend of weighted cross-entropy and Dice-based focal loss.

    seg_loss = (1 - beta) * (10 * WCE) + beta * UFD

    where WCE is `weighted_categorical_crossentropy` and UFD is
    `unified_focal_dice_loss`. The caller supplies beta (e.g. from
    `compute_beta_schedule`), so early training (beta near 0) is driven by
    the weighted CE and late training (beta near 1) by the Dice-based term.

    Args:
        y_true: Ground truth (bs, H, W, num_classes) one-hot
        y_pred: Predictions (bs, H, W, num_classes) probabilities
        class_weights: Per-class weights (num_classes,)
        beta: Blend factor in [0, 1]
        focal_gamma: `gamma` forwarded to the Dice-based term (default 0.5)
        use_focal_alpha: If True, `class_weights` also weight the Dice-based
            term; otherwise that term is unweighted
        exclude_class: Optional class index forwarded to both components

    Returns:
        (seg_loss, wcce_loss, focal_loss): the blended loss, the scaled
        weighted-CE component and the Dice-based component; the last two are
        returned for monitoring.
    """
    # Phase 1 component; the factor 10 is the scaling chosen by the authors
    wce_component = 10 * weighted_categorical_crossentropy(
        y_true, y_pred, class_weights, exclude_class=exclude_class
    )

    # Phase 2 component, optionally sharing the same class weights
    dice_alpha = class_weights if use_focal_alpha else None
    dice_component = unified_focal_dice_loss(
        y_true, y_pred,
        gamma=focal_gamma,
        alpha=dice_alpha,
        exclude_class=exclude_class,
    )

    # Convex combination controlled by beta
    blended = (1.0 - beta) * wce_component + beta * dice_component

    return blended, wce_component, dice_component
554
+
555
+ ###################### Training Functions ######################
556
+
557
@tf.function
def train_step(input_image, target_onehot, model, optimizer,
               class_weights, beta, focal_gamma,
               use_focal_alpha=True, exclude_class=None):
    """
    Run one optimisation step on a single batch.

    Args:
        input_image: Input FLAIR, documented as (bs, 256, 256, 1) in [-1, 1]
        target_onehot: Target mask (bs, 256, 256, num_classes) one-hot
        model: The segmentation model to update
        optimizer: Optimizer applied to `model.trainable_variables`
        class_weights: (num_classes,) weight per class
        beta: Current blend factor for the phase transition
        focal_gamma: Focusing parameter of the Dice-based loss term
        use_focal_alpha: Whether class weights also apply to that term
        exclude_class: Optional class index excluded from the loss

    Returns:
        (seg_loss, wcce_loss, focal_loss) as returned by
        `adaptive_segmentation_loss` for this batch.
    """
    with tf.GradientTape() as tape:
        # Forward pass in training mode
        predictions = model(input_image, training=True)

        # Blended loss plus its two components (kept for logging)
        seg_loss, wcce_loss, focal_loss = adaptive_segmentation_loss(
            target_onehot, predictions, class_weights, beta,
            focal_gamma=focal_gamma,
            use_focal_alpha=use_focal_alpha,
            exclude_class=exclude_class,
        )

    # Back-propagate the blended loss only and update the weights
    weights = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(seg_loss, weights), weights))

    return seg_loss, wcce_loss, focal_loss
591
+
592
def generate_and_save_images(model, test_input, test_target,
                             epoch, save_path, num_classes):
    """
    Predict each sample of a batch and save a five-panel figure per sample.

    Panels: input FLAIR, ground truth, predicted classes, maximum class
    probability, and a binary error map. Files are written as
    `epoch_<epoch:03d>_<sample index starting at 1>.png` inside `save_path`.

    Args:
        model: The segmentation model
        test_input: Paired input batch, documented as (bs, 256, 512, 1); the
            first 256 columns of axis 2 are used as the FLAIR image
        test_target: Target masks, documented as (bs, 256, 256)
        epoch: Current epoch number (used in the file name)
        save_path: Output directory; must support the `/` operator
            (e.g. `pathlib.Path`)
        num_classes: Number of classes (sets the colour range of label maps)
    """
    def _panel(position, title, image, with_colorbar=True, **imshow_kwargs):
        # Draw one of the five side-by-side panels of the current figure.
        plt.subplot(1, 5, position)
        plt.title(title)
        plt.imshow(image, **imshow_kwargs)
        if with_colorbar:
            plt.colorbar()
        plt.axis('off')

    label_range = {'cmap': 'jet', 'vmin': 0, 'vmax': num_classes - 1}

    # Read the batch size from the shape instead of copying the whole batch
    # to host memory with `.numpy()` just to count samples.
    batch_size = test_input.shape[0]

    for ik in range(batch_size):
        # Extract FLAIR and restore the batch axis expected by the model
        flair_normalized = tf.expand_dims(test_input[ik, :, :256, :], axis=0)

        # Generate prediction
        prediction_softmax = model(flair_normalized, training=False)

        # Convert to class labels
        pred_classes = tf.argmax(prediction_softmax, axis=-1).numpy()
        target_mask = test_target[ik].numpy()

        max_prob = tf.reduce_max(prediction_softmax[0], axis=-1).numpy()
        error_map = (pred_classes[0] != target_mask).astype(float)

        fig = plt.figure(figsize=(20, 5))
        try:
            _panel(1, 'Input FLAIR', flair_normalized[0, :, :, 0],
                   with_colorbar=False, cmap='gray')
            _panel(2, 'Ground Truth', target_mask, **label_range)
            _panel(3, 'Predicted Classes', pred_classes[0], **label_range)
            _panel(4, 'Max Probability', max_prob, cmap='viridis', vmin=0, vmax=1)
            _panel(5, 'Error Map (Red=Wrong)', error_map, cmap='Reds', vmin=0, vmax=1)

            plt.tight_layout()
            plt.savefig(save_path / f'epoch_{epoch:03d}_{ik+1}.png', dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails
            plt.close(fig)
659
+
660
+ ###################### Main Training Function ######################
661
+
662
+ def train_net(config: ExperimentConfig):
663
+ """
664
+ Main training function for a Specific U-Net
665
+
666
+ Args:
667
+ config: ExperimentConfig object
668
+ """
669
+ print("\n" + "="*70)
670
+ print(f"TRAINING {config.architecture_name}: {config.exp_name}")
671
+ print("="*70)
672
+ print(f"Variant: {config.variant}")
673
+ print(f"Preprocessing: {config.preprocessing}")
674
+ print(f"Class scenario: {config.class_scenario} ({config.num_classes} classes)")
675
+ print(f"Fold: {config.fold_id}")
676
+ print(f"Epochs: {config.epochs}")
677
+ print(f"Batch size: {config.batch_size}")
678
+ print(f"Loss: Weighted Categorical Cross-Entropy → Unified Focal")
679
+ print("="*70 + "\n")
680
+
681
+ # Check initial GPU memory
682
+ get_gpu_memory_info()
683
+
684
+ # Initialize data loader
685
+ data_config = DataConfig()
686
+ data_loader = P2DataLoader(data_config)
687
+
688
+ # Load datasets
689
+ print("Loading training data...")
690
+ train_dataset = data_loader.create_dataset_for_fold(
691
+ fold_id=config.fold_id,
692
+ split='train',
693
+ preprocessing=config.preprocessing,
694
+ class_scenario=config.class_scenario,
695
+ batch_size=config.batch_size,
696
+ shuffle=True
697
+ )
698
+
699
+ print("Loading validation data...")
700
+ val_dataset = data_loader.create_dataset_for_fold(
701
+ fold_id=config.fold_id,
702
+ split='val',
703
+ preprocessing=config.preprocessing,
704
+ class_scenario=config.class_scenario,
705
+ batch_size=config.batch_size,
706
+ shuffle=False
707
+ )
708
+
709
+ # Get dataset sizes
710
+ # Note: from_generator pipelines always report cardinality as INFINITE (-1)
711
+ # even with .cache(), so we derive the batch count from the slice list instead.
712
+ # We iterate once here; this also warms the in-memory cache so epoch 1 is fast.
713
+ print("Warming dataset cache (first pass over data — subsequent epochs use RAM)...")
714
+ train_size = sum(1 for _ in train_dataset)
715
+ val_size = sum(1 for _ in val_dataset)
716
+ # ⚠️ Do NOT rebuild the datasets here — that would create new generators and
717
+ # throw away the cache we just populated.
718
+
719
+ print(f"Training samples (batches): {train_size}")
720
+ print(f"Validation samples (batches): {val_size}\n")
721
+
722
+ # Compute or load class weights
723
+ print("Computing class weights from training data...")
724
+ try:
725
+ class_weights = load_class_weights(
726
+ config.fold_id, config.class_scenario,
727
+ config.preprocessing, config.weights_dir
728
+ )
729
+ print("✅ Loaded pre-computed class weights")
730
+ except FileNotFoundError:
731
+ print("Computing class weights (this may take a few minutes)...")
732
+ results = compute_and_save_class_weights(
733
+ config.fold_id, config.class_scenario,
734
+ config.preprocessing, str(config.weights_dir)
735
+ )
736
+ class_weights = np.array(results['class_weights'], dtype=np.float32)
737
+
738
+ print(f"Class weights: {class_weights}")
739
+
740
+ # Build model
741
+ print(f"\n🏗️ Building {config.architecture_name} model...")
742
+
743
+ if config.architecture_name == 'unet':
744
+ from unet_model import build_unet_3class as build_specific_3class # must be updated with the actual used model for traininig
745
+ elif config.architecture_name == 'attnunet':
746
+ from attn_unet_model import build_attention_unet_3class as build_specific_3class
747
+ elif config.architecture_name == 'dlv3unet':
748
+ from dlv3_unet_model_GN import build_deeplabv3_unet_3class as build_specific_3class
749
+ elif config.architecture_name == 'transunet':
750
+ from trans_unet_model import build_trans_unet_3class as build_specific_3class
751
+ else:
752
+ print(f"❌ Error loading model: Invalid Model Name")
753
+ raise
754
+
755
+ model = build_specific_3class(input_shape=(256, 256, 1), num_classes=config.num_classes)
756
+
757
+ print(f"Model parameters: {model.count_params():,}\n")
758
+
759
+ # Optimizer (will be updated with ReduceLROnPlateau)
760
+ optimizer = tf.keras.optimizers.legacy.Adam(
761
+ config.learning_rate, beta_1=config.beta_1
762
+ )
763
+
764
+ # Initialize optimizer variables
765
+ print("Initializing optimizer variables...")
766
+ dummy_input = tf.zeros((1, 256, 256, 1))
767
+
768
+ with tf.GradientTape() as tape:
769
+ output = model(dummy_input, training=True)
770
+ dummy_loss = tf.reduce_mean(output)
771
+
772
+ # Apply dummy gradients to build optimizer variables
773
+ grads = tape.gradient(dummy_loss, model.trainable_variables)
774
+ optimizer.apply_gradients(zip(grads, model.trainable_variables))
775
+ print("✅ Optimizer variables initialized\n")
776
+
777
+ # Checkpoint
778
+ checkpoint = tf.train.Checkpoint(
779
+ optimizer=optimizer,
780
+ model=model
781
+ )
782
+
783
+ checkpoint_prefix = config.checkpoint_dir / "ckpt"
784
+ manager = tf.train.CheckpointManager(
785
+ checkpoint, config.checkpoint_dir, max_to_keep=1
786
+ )
787
+
788
+ if manager.latest_checkpoint:
789
+ checkpoint.restore(manager.latest_checkpoint)
790
+ print(f"✅ Restored from checkpoint: {manager.latest_checkpoint}\n")
791
+ else:
792
+ print("Starting training from scratch\n")
793
+
794
+ # Get example for visualization
795
+ skip_n = 1 # min(100 // config.batch_size, val_size - 1)
796
+ example_paired, example_target, _, _ = next(iter(val_dataset.skip(skip_n).take(20)))
797
+
798
+ print("Initializing metrics computer...")
799
+ if config.num_classes == 4:
800
+ class_names = ['Background', 'Ventricles', 'Normal_WMH', 'Abnormal_WMH']
801
+ elif config.num_classes == 3:
802
+ class_names = ['Background', 'Ventricles', 'Abnormal_WMH']
803
+
804
+ # Training history
805
+ history = {
806
+ 'train_loss': [],
807
+ 'wce_loss': [],
808
+ 'ufd_loss': [],
809
+ 'val_loss': [],
810
+ 'val_loss_wce': [],
811
+ 'val_loss_ufd': [],
812
+ 'val_metrics': [],
813
+ 'beta_value': []
814
+ }
815
+
816
+ # Training loop
817
+ best_val_loss = float('inf')
818
+ best_val_dice = float('-inf')
819
+ exclude_class = 2 if config.num_classes == 4 else None # Exclude class 2 only in 4-class
820
+
821
+ try:
822
+ for epoch in range(config.epochs):
823
+ start_time = time.time()
824
+
825
+ # Compute beta for this epoch
826
+ beta_value = compute_beta_schedule(
827
+ epoch, config.epochs,
828
+ config.beta_threshold, config.beta_smoothness
829
+ )
830
+
831
+ # Training metrics
832
+ epoch_losses = []
833
+ epoch_loss_wce = []
834
+ epoch_loss_ufd = []
835
+
836
+ # Training loop
837
+
838
+ # Update learning rate based on epoch
839
+
840
+ # y1 = 2 * np.exp(-np.log(400) * x) # original
841
+ # y2 = 2 * np.exp(-np.log(400) * x**2) # milder
842
+ # y3 = 2 * np.exp(-np.log(400) * x**3) # even milder ✅
843
+ # y4 = 2 * np.exp(-np.log(400) * x**5) # very mild
844
+
845
+ new_lr = config.learning_rate * np.exp(-np.log(400) * (epoch / config.epochs)**3) # Steadily and exponentially decay from 2e-4 to 5e-7
846
+ optimizer.learning_rate.assign(new_lr)
847
+
848
+ print(f"\nEpoch {epoch+1}/{config.epochs} (β={beta_value.numpy():.4f}) (lr={new_lr*10000:.3f} 10-4)")
849
+ train_bar = tqdm(train_dataset, total=train_size, desc="Training")
850
+
851
+ for paired_input, target_mask, patient_id_tensor, slice_num_tensor in train_bar:
852
+
853
+ patient_id = patient_id_tensor.numpy()[0].decode('utf-8') # batch dim + bytes→str
854
+ slice_num = int(slice_num_tensor.numpy()[0])
855
+
856
+ # ✅ Prepare inputs: normalize FLAIR + one-hot encode target
857
+ flair_normalized, target_onehot = prepare_inputs(
858
+ paired_input, target_mask, config.num_classes
859
+ )
860
+
861
+ # Train step
862
+ loss, wce_loss, ufd_loss = train_step(
863
+ flair_normalized, target_onehot,
864
+ model, optimizer, class_weights,
865
+ beta_value, config.focal_gamma
866
+ )
867
+
868
+ epoch_losses.append(loss.numpy())
869
+ epoch_loss_wce.append(wce_loss.numpy())
870
+ epoch_loss_ufd.append(ufd_loss.numpy())
871
+
872
+ # Update progress bar
873
+ train_bar.set_postfix({
874
+ 'seg_loss': f"{loss.numpy():.5f}",
875
+ 'wce_loss': f"{wce_loss.numpy():.5f}",
876
+ 'ufd_loss': f"{ufd_loss.numpy():.5f}",
877
+ })
878
+
879
+ # Calculate epoch average
880
+ avg_train_loss = np.mean(epoch_losses)
881
+ avg_train_loss_wce = np.mean(epoch_loss_wce)
882
+ avg_train_loss_ufd = np.mean(epoch_loss_ufd)
883
+
884
+ history['train_loss'].append(avg_train_loss)
885
+ history['wce_loss'].append(avg_train_loss_wce)
886
+ history['ufd_loss'].append(avg_train_loss_ufd)
887
+ history['beta_value'].append(float(beta_value.numpy()))
888
+
889
+ # Validation
890
+ val_losses = []
891
+ val_losses_wce = []
892
+ val_losses_ufd = []
893
+ all_val_true = []
894
+ all_val_pred = []
895
+
896
+ for val_paired, val_target, patient_id_tensor, slice_num_tensor in val_dataset:
897
+ try:
898
+
899
+ patient_id = patient_id_tensor.numpy()[0].decode('utf-8') # batch dim + bytes→str
900
+ slice_num = int(slice_num_tensor.numpy()[0])
901
+
902
+ val_flair_norm, val_target_onehot = prepare_inputs(
903
+ val_paired, val_target, config.num_classes
904
+ )
905
+
906
+ val_pred = model(val_flair_norm, training=False)
907
+
908
+ val_loss, val_wce_loss, val_ufd_loss = adaptive_segmentation_loss(
909
+ val_target_onehot, val_pred, class_weights,
910
+ beta_value, focal_gamma=config.focal_gamma, exclude_class=exclude_class
911
+ )
912
+
913
+ # Store true and prediction values for metrics calculation
914
+ all_val_true.append(val_target_onehot)
915
+ all_val_pred.append(val_pred)
916
+
917
+ if not tf.math.is_nan(val_loss):
918
+ val_losses.append(val_loss.numpy())
919
+ val_losses_wce.append(val_wce_loss.numpy())
920
+ val_losses_ufd.append(val_ufd_loss.numpy())
921
+ except:
922
+ continue
923
+
924
+ if len(val_losses) > 0:
925
+ avg_val_loss = np.mean(val_losses)
926
+ avg_val_loss_wce = np.mean(val_losses_wce)
927
+ avg_val_loss_ufd = np.mean(val_losses_ufd)
928
+
929
+ history['val_loss'].append(avg_val_loss)
930
+ history['val_loss_wce'].append(avg_val_loss_wce)
931
+ history['val_loss_ufd'].append(avg_val_loss_ufd)
932
+
933
+ # Compute class-wise metrics
934
+ val_metrics = compute_classwise_metrics(
935
+ all_val_true, all_val_pred,
936
+ config.num_classes#, exclude_class=exclude_class
937
+ )
938
+ history['val_metrics'].append(val_metrics)
939
+
940
+ # Print validation results
941
+ epoch_time = time.time() - start_time
942
+ print(f"\n{'='*70}")
943
+ print(f"Epoch {epoch+1}/{config.epochs} Summary (Time: {epoch_time:.2f}s)")
944
+ print(f"{'='*70}")
945
+ print(f"Training Loss: {avg_train_loss:.4f} | wce_loss: {avg_train_loss_wce:.4f} | ufd_loss: {avg_train_loss_ufd:.4f}")
946
+ print(f"Validation Loss: {avg_val_loss:.4f}")
947
+ print(f"\nClass-wise Dice Scores:")
948
+ for class_name, dice_val in val_metrics['dice'].items():
949
+ if class_name != 'mean':
950
+ print(f" {class_name}: {dice_val:.4f}")
951
+ if class_name == f"class_{config.num_classes - 1}":
952
+ abwmh_val_dice = dice_val
953
+ elif class_name == f"class_1":
954
+ vent_val_dice = dice_val
955
+ print(f" Mean Dice: {val_metrics['dice']['mean']:.4f}")
956
+ print(f"\nClass-wise Precision:")
957
+ for class_name, prec_val in val_metrics['precision'].items():
958
+ if class_name != 'mean':
959
+ print(f" {class_name}: {prec_val:.4f}")
960
+ print(f" Mean Precision: {val_metrics['precision']['mean']:.4f}")
961
+ print(f"\nClass-wise Recall:")
962
+ for class_name, rec_val in val_metrics['recall'].items():
963
+ if class_name != 'mean':
964
+ print(f" {class_name}: {rec_val:.4f}")
965
+ print(f" Mean Recall: {val_metrics['recall']['mean']:.4f}")
966
+ print(f"{'='*70}\n")
967
+
968
+ # Save best model based on overall validation performance
969
+ overal_val_performance = 0.6 * abwmh_val_dice + 0.3 * vent_val_dice + 0.1 * (1 - 1*avg_val_loss)
970
+ if overal_val_performance > best_val_dice and beta_value.numpy() > 0.9:
971
+ best_val_dice = overal_val_performance
972
+ model.save_weights(f"{config.checkpoint_dir}/best_dice_model.h5")
973
+ print(f"✓ Best model saved (performance: {best_val_dice:.4f})")
974
+ else:
975
+ print("Warning: No valid validation batches")
976
+ history['val_loss'].append(float('nan'))
977
+ history['val_metrics'].append({})
978
+
979
+ # Save checkpoint
980
+ if (epoch + 1) % 5 == 0 and False:
981
+ manager.save()
982
+ print(f" 💾 Saved checkpoint")
983
+
984
+ # Generate sample images
985
+ if ((epoch + 1) % 5 == 0 or epoch == 0) or True:
986
+ generate_and_save_images(
987
+ model, example_paired, example_target,
988
+ epoch + 1, config.figures_dir, config.num_classes
989
+ )
990
+ print(f" 📊 Saved visualization")
991
+
992
+ # # Save final model
993
+ # final_model_path = config.checkpoint_dir / "final_model.h5"
994
+ # model.save(final_model_path)
995
+ # print(f"\n✅ Training complete! Final model saved to {final_model_path}")
996
+
997
+ # Save history
998
+ history_serializable = {
999
+ key: [float(val) if isinstance(val, (int, float, np.number)) else val
1000
+ for val in values]
1001
+ for key, values in history.items()
1002
+ }
1003
+
1004
+ history_file = config.checkpoint_dir / "history.json"
1005
+ with open(history_file, 'w') as f:
1006
+ json.dump(history_serializable, f, indent=2)
1007
+
1008
+ return history, history_file
1009
+
1010
+ finally:
1011
+ # CRITICAL: Always cleanup, even if training fails
1012
+ print("\n🧹 Cleaning up resources...")
1013
+
1014
+ # Delete models explicitly to break references
1015
+ try:
1016
+ del model
1017
+ del optimizer
1018
+ del checkpoint
1019
+ del manager
1020
+ del train_dataset
1021
+ del val_dataset
1022
+ print("✅ Deleted model objects")
1023
+ except Exception as e:
1024
+ print(f"⚠️ Error deleting objects: {e}")
1025
+
1026
+ # Clear GPU memory
1027
+ clear_gpu_memory()
1028
+
1029
+ # Check final GPU memory
1030
+ get_gpu_memory_info()
1031
+
1032
+ ###################### Main Execution ######################
1033
+
1034
if __name__ == "__main__":
    # Example run: 3-class scenario, standard preprocessing, fold 0.
    # Architecture names listed by the author:
    #   'unet', 'attnunet', 'dlv3unet', 'transunet'
    config = ExperimentConfig(
        variant=3,
        preprocessing='standard',
        class_scenario='3class',
        fold_id=0,
        architecture_name='dlv3unet',
    )

    history, history_path = train_net(config)

    # Closing banner so the end of the run is easy to spot in the log.
    banner = "=" * 70
    print("\n" + banner)
    print("U-NET TRAINING COMPLETE")
    print(banner)
1051
+
models/for_WMH_Vent/model_training_scripts/trans_unet_model.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###################### Libraries ######################
2
+ # Deep Learning
3
+ import tensorflow as tf
4
+ import keras
5
+ from keras.models import Model, load_model
6
+ from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
7
+ from keras import backend as K
8
+ from tensorflow.keras import layers, optimizers, callbacks
9
+ from keras.utils import to_categorical
10
+
11
+
12
def build_trans_unet_3class(input_shape=(256, 256, 1), num_classes=3):
    """
    Build a TransUNet-style segmentation model.

    The network is a four-stage CNN encoder, a convolutional bottleneck whose
    feature map is flattened into a token sequence and refined by four
    transformer blocks (self-attention + feed-forward), and a four-stage CNN
    decoder that concatenates the matching encoder features at every stage.

    Args:
        input_shape: ``(height, width, channels)`` of the input image. Height
            and width must be known integers divisible by 16, because the
            encoder halves the resolution four times and the decoder
            concatenates with the encoder feature maps.
        num_classes: Number of output channels. ``1`` produces a single
            sigmoid channel; any other value produces a softmax over
            ``num_classes`` channels.

    Returns:
        A ``tf.keras.Model`` named ``'TransUNet'``.

    Raises:
        ValueError: If the spatial dimensions of ``input_shape`` are unknown
            or not divisible by 16.
    """
    height, width = input_shape[0], input_shape[1]
    if height is None or width is None or height % 16 or width % 16:
        raise ValueError(
            "input_shape height and width must be known and divisible by 16, "
            f"got {tuple(input_shape)}"
        )

    d_model = 768  # transformer token dimension (bottleneck channel count)
    # Token grid size after the four 2x2 poolings (16x16 for a 256x256 input).
    h, w = height // 16, width // 16

    inputs = layers.Input(input_shape)

    # ==================== CNN ENCODER ====================
    # Each stage: two 3x3 convs, dropout, then 2x2 max-pooling. The
    # post-dropout tensor of every stage is kept for the decoder skips.
    skips = []
    x = inputs
    for filters, drop_rate in ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2)):
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = layers.Dropout(drop_rate)(x)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # ==================== TRANSFORMER BOTTLENECK ====================
    # Bottleneck features: h x w x d_model
    bottleneck = layers.Conv2D(d_model, 3, padding='same', activation='relu')(x)
    bottleneck = layers.Dropout(0.3)(bottleneck)

    # Flatten the spatial grid into a sequence of h*w tokens.
    transformer_input = layers.Reshape((h * w, d_model))(bottleneck)

    # Add positional encoding.
    # NOTE(review): the Embedding layer is called on a concrete tf.range
    # tensor rather than on a symbolic model tensor, so its weights may not be
    # tracked/trained/saved by the returned Model -- confirm. Behaviour is left
    # unchanged here so that existing trained weights remain loadable.
    positions = tf.range(start=0, limit=h * w, delta=1)
    pos_encoding = layers.Embedding(h * w, d_model)(positions)
    transformer_input = transformer_input + pos_encoding

    for _ in range(4):  # 4 transformer layers
        # Multi-head self-attention with residual connection + layer norm
        attention_output = layers.MultiHeadAttention(
            num_heads=8, key_dim=d_model // 8, dropout=0.1
        )(transformer_input, transformer_input)
        attention_output = layers.Dropout(0.1)(attention_output)
        transformer_input = layers.LayerNormalization()(transformer_input + attention_output)

        # Position-wise feed-forward network with residual + layer norm
        ffn = layers.Dense(d_model * 2, activation='relu')(transformer_input)
        ffn = layers.Dropout(0.1)(ffn)
        ffn = layers.Dense(d_model)(ffn)
        ffn = layers.Dropout(0.1)(ffn)
        transformer_input = layers.LayerNormalization()(transformer_input + ffn)

    # Reshape the token sequence back to a spatial feature map.
    transformer_output = layers.Reshape((h, w, d_model))(transformer_input)

    # Project back to the decoder's channel width.
    x = layers.Conv2D(512, 1, activation='relu')(transformer_output)
    x = layers.Dropout(0.3)(x)

    # ==================== CNN DECODER ====================
    # Each stage: 2x transposed-conv upsampling, concatenation with the
    # matching encoder output, dropout, then two 3x3 convs.
    decoder_stages = ((512, 0.2), (256, 0.2), (128, 0.1), (64, 0.1))
    for (filters, drop_rate), skip in zip(decoder_stages, reversed(skips)):
        x = layers.Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
        x = layers.Concatenate()([x, skip])
        x = layers.Dropout(drop_rate)(x)
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)

    # ==================== OUTPUT LAYER ====================
    if num_classes == 1:
        outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    else:
        outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)

    model = tf.keras.Model(inputs, outputs, name='TransUNet')
    return model
models/for_WMH_Vent/model_training_scripts/unet_model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###################### Libraries ######################
2
+ # Deep Learning
3
+ import keras
4
+ from keras.models import Model
5
+ from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
6
+
7
+
8
def build_unet_3class(input_shape=(256, 256, 1), num_classes=3):
    """
    Build a U-Net with dropout regularisation.

    Four encoder stages (two 3x3 ReLU convs, 2x2 max-pooling, dropout), a
    1024-filter bottleneck, and four decoder stages (2x transposed-conv
    upsampling, concatenation with the matching encoder convs, dropout, two
    3x3 ReLU convs). No batch normalization layers are applied.

    Args:
        input_shape: Shape of a single input image, without the batch axis.
        num_classes: Number of output channels; ``1`` yields a sigmoid head,
            any other value a softmax head.

    Returns:
        A Keras ``Model`` named ``'UNet'``.
    """

    def conv_pair(tensor, filters):
        # Two consecutive 3x3 same-padded ReLU convolutions.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_shape)

    # Encoder: remember the pre-pooling conv output of each stage for the skips.
    encoder_outputs = []
    x = inputs
    for filters, rate in ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2)):
        x = conv_pair(x, filters)
        encoder_outputs.append(x)
        x = MaxPooling2D()(x)
        x = keras.layers.Dropout(rate)(x)

    # Bottleneck
    x = conv_pair(x, 1024)
    x = keras.layers.Dropout(0.3)(x)

    # Decoder: consume the stored encoder outputs deepest-first.
    for filters, rate in ((512, 0.2), (256, 0.2), (128, 0.1), (64, 0.1)):
        skip = encoder_outputs.pop()
        x = Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
        x = concatenate([x, skip])
        x = keras.layers.Dropout(rate)(x)
        x = conv_pair(x, filters)

    # Output layer: sigmoid for a single channel, softmax otherwise.
    head_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = Conv2D(num_classes, 1, activation=head_activation)(x)

    return Model(inputs, outputs, name='UNet')
models/for_WMH_Vent/model_training_scripts/utility_functions.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P4 Article - Utility Functions
3
+
4
+
5
+ Developer:
6
+ "Mahdi Bashiri Bawil"
7
+ """
8
+
9
+ import gc
10
+ import tensorflow as tf
11
+ from tensorflow.keras import backend as K
12
+
13
print("TensorFlow Version:", tf.__version__)

###################### GPU Configuration ######################

# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected - training will be slow")
else:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as e:
        # Reported instead of raised so that importing this module never fails.
        print(f"GPU configuration error: {e}")
    else:
        print("✅ GPU memory growth enabled")
        print(f" Available GPUs: {len(physical_devices)}")
29
+
30
+ """
31
+ GPU Memory Management for Sequential Experiments
32
+ To properly release memory between experiments
33
+ """
34
+
35
+
36
def clear_gpu_memory():
    """
    Best-effort GPU memory cleanup between experiments.

    Call this after each experiment completes. The steps are: clear the Keras
    session, run the garbage collector, reset the default graph, reset the
    memory statistics of every visible GPU, and finally try to toggle memory
    growth. Every GPU-specific step is best-effort: a failure is reported on
    stdout and never raised, so cleanup cannot mask the outcome of the
    experiment that preceded it.
    """
    print("\n" + "="*70)
    print("CLEANING UP GPU MEMORY")
    print("="*70)

    # Clear Keras session
    K.clear_session()
    print("✅ Cleared Keras session")

    # Force garbage collection multiple times
    for _ in range(3):
        gc.collect()
    print("✅ Ran garbage collection (3 passes)")

    # Reset TensorFlow graphs
    tf.compat.v1.reset_default_graph()
    print("✅ Reset default graph")

    gpus = tf.config.list_physical_devices('GPU')

    # Reset the memory statistics of every GPU (not only the first one),
    # using the same device naming as get_gpu_memory_info().
    for gpu in gpus:
        device_name = gpu.name.replace('/physical_device:', '')
        try:
            tf.config.experimental.reset_memory_stats(device_name)
            print(f"✅ Reset GPU memory stats ({device_name})")
        except Exception as e:
            # Previously swallowed silently; report it but keep going.
            print(f"⚠️ Could not reset GPU memory stats for {device_name}: {e}")

    # Toggle memory growth off and on again.
    # NOTE: TensorFlow documents that set_memory_growth raises RuntimeError
    # once the runtime has been initialized, so after a training run this step
    # is expected to print the warning below rather than succeed.
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, False)
                tf.config.experimental.set_memory_growth(gpu, True)
            print("✅ Reset memory growth (flushed allocator)")
        except Exception as e:
            print(f"⚠️ Could not reset memory growth: {e}")

    print("="*70 + "\n")
80
+
81
+
82
def get_gpu_memory_info():
    """
    Print the current and peak memory usage of each visible GPU.

    Intended for spotting memory leaks between experiments. Any error while
    querying TensorFlow is reported on stdout instead of being raised.
    """
    bytes_per_mb = 1024**2
    try:
        for gpu in tf.config.list_physical_devices('GPU'):
            # Strip the physical-device prefix to get the name passed to
            # get_memory_info.
            device_name = gpu.name.replace('/physical_device:', '')
            stats = tf.config.experimental.get_memory_info(device_name)
            current_mb = stats['current'] / bytes_per_mb
            peak_mb = stats['peak'] / bytes_per_mb
            print(f"GPU Memory - Current: {current_mb:.1f} MB, Peak: {peak_mb:.1f} MB")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")
models/for_WMH_Vent/results_fold_avg_var_1_zscore2/models/standard_3class/download_models.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Visit our Hugging Face link for downloading the trained models.