Upload 31 files
Browse files- models/for_WMH_Vent/class_weights/class_weights_fold0_standard_3class.json +27 -0
- models/for_WMH_Vent/class_weights/class_weights_fold1_standard_3class.json +27 -0
- models/for_WMH_Vent/class_weights/class_weights_fold2_standard_3class.json +27 -0
- models/for_WMH_Vent/class_weights/class_weights_fold3_standard_3class.json +27 -0
- models/for_WMH_Vent/data_splits/concat_fold_assignments.json +475 -0
- models/for_WMH_Vent/data_splits/fold_assignments.json +543 -0
- models/for_WMH_Vent/data_splits/for_assignment.py +234 -0
- models/for_WMH_Vent/data_splits/local_fold_assignments.json +421 -0
- models/for_WMH_Vent/data_splits/public_fold_assignments.json +102 -0
- models/for_WMH_Vent/download_models.txt +1 -0
- models/for_WMH_Vent/folds_results_zscore2_all/per_class_summary.csv +9 -0
- models/for_WMH_Vent/folds_results_zscore2_all/test_metrics_all_variants_folds.csv +27 -0
- models/for_WMH_Vent/folds_results_zscore2_all/training_info_all_variants_folds.csv +17 -0
- models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_test.csv +5 -0
- models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_training.csv +5 -0
- models/for_WMH_Vent/model_training_scripts/attn_unet_model.py +85 -0
- models/for_WMH_Vent/model_training_scripts/base_runner_all.py +23 -0
- models/for_WMH_Vent/model_training_scripts/dlv3_unet_model.py +198 -0
- models/for_WMH_Vent/model_training_scripts/dlv3_unet_model_GN.py +247 -0
- models/for_WMH_Vent/model_training_scripts/p4_compute_class_weights.py +353 -0
- models/for_WMH_Vent/model_training_scripts/p4_data_loader.py +912 -0
- models/for_WMH_Vent/model_training_scripts/p4_error_analysis.py +1033 -0
- models/for_WMH_Vent/model_training_scripts/p4_folds_results_aggregator.py +611 -0
- models/for_WMH_Vent/model_training_scripts/p4_inference.py +1146 -0
- models/for_WMH_Vent/model_training_scripts/p4_run_experiments_all.py +576 -0
- models/for_WMH_Vent/model_training_scripts/p4_unet_viz.py +640 -0
- models/for_WMH_Vent/model_training_scripts/p4_variant_all_net.py +1051 -0
- models/for_WMH_Vent/model_training_scripts/trans_unet_model.py +125 -0
- models/for_WMH_Vent/model_training_scripts/unet_model.py +87 -0
- models/for_WMH_Vent/model_training_scripts/utility_functions.py +96 -0
- models/for_WMH_Vent/results_fold_avg_var_1_zscore2/models/standard_3class/download_models.txt +1 -0
models/for_WMH_Vent/class_weights/class_weights_fold0_standard_3class.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fold_id": 0,
|
| 3 |
+
"class_scenario": "3class",
|
| 4 |
+
"preprocessing": "standard",
|
| 5 |
+
"num_classes": 3,
|
| 6 |
+
"total_pixels": 119144448,
|
| 7 |
+
"class_pixel_counts": [
|
| 8 |
+
118420367,
|
| 9 |
+
496384,
|
| 10 |
+
227697
|
| 11 |
+
],
|
| 12 |
+
"class_frequencies": [
|
| 13 |
+
0.993922662682528,
|
| 14 |
+
0.004166236936193619,
|
| 15 |
+
0.0019111003812783622
|
| 16 |
+
],
|
| 17 |
+
"class_weights": [
|
| 18 |
+
0.003950922707703595,
|
| 19 |
+
0.9423307646632635,
|
| 20 |
+
2.0537183126290333
|
| 21 |
+
],
|
| 22 |
+
"class_names": [
|
| 23 |
+
"Background",
|
| 24 |
+
"Ventricles",
|
| 25 |
+
"Abnormal WMH"
|
| 26 |
+
]
|
| 27 |
+
}
|
models/for_WMH_Vent/class_weights/class_weights_fold1_standard_3class.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fold_id": 1,
|
| 3 |
+
"class_scenario": "3class",
|
| 4 |
+
"preprocessing": "standard",
|
| 5 |
+
"num_classes": 3,
|
| 6 |
+
"total_pixels": 119341056,
|
| 7 |
+
"class_pixel_counts": [
|
| 8 |
+
118646442,
|
| 9 |
+
470627,
|
| 10 |
+
223987
|
| 11 |
+
],
|
| 12 |
+
"class_frequencies": [
|
| 13 |
+
0.994179588958891,
|
| 14 |
+
0.003943546469037445,
|
| 15 |
+
0.0018768645720714924
|
| 16 |
+
],
|
| 17 |
+
"class_weights": [
|
| 18 |
+
0.003834061426337229,
|
| 19 |
+
0.96633402011123,
|
| 20 |
+
2.029831918462433
|
| 21 |
+
],
|
| 22 |
+
"class_names": [
|
| 23 |
+
"Background",
|
| 24 |
+
"Ventricles",
|
| 25 |
+
"Abnormal WMH"
|
| 26 |
+
]
|
| 27 |
+
}
|
models/for_WMH_Vent/class_weights/class_weights_fold2_standard_3class.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fold_id": 2,
|
| 3 |
+
"class_scenario": "3class",
|
| 4 |
+
"preprocessing": "standard",
|
| 5 |
+
"num_classes": 3,
|
| 6 |
+
"total_pixels": 119472128,
|
| 7 |
+
"class_pixel_counts": [
|
| 8 |
+
118787277,
|
| 9 |
+
464952,
|
| 10 |
+
219899
|
| 11 |
+
],
|
| 12 |
+
"class_frequencies": [
|
| 13 |
+
0.994267692293888,
|
| 14 |
+
0.0038917194142553484,
|
| 15 |
+
0.001840588291856658
|
| 16 |
+
],
|
| 17 |
+
"class_weights": [
|
| 18 |
+
0.0037673539050414257,
|
| 19 |
+
0.9622481463361134,
|
| 20 |
+
2.033984499758845
|
| 21 |
+
],
|
| 22 |
+
"class_names": [
|
| 23 |
+
"Background",
|
| 24 |
+
"Ventricles",
|
| 25 |
+
"Abnormal WMH"
|
| 26 |
+
]
|
| 27 |
+
}
|
models/for_WMH_Vent/class_weights/class_weights_fold3_standard_3class.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fold_id": 3,
|
| 3 |
+
"class_scenario": "3class",
|
| 4 |
+
"preprocessing": "standard",
|
| 5 |
+
"num_classes": 3,
|
| 6 |
+
"total_pixels": 119734272,
|
| 7 |
+
"class_pixel_counts": [
|
| 8 |
+
118973104,
|
| 9 |
+
509903,
|
| 10 |
+
251265
|
| 11 |
+
],
|
| 12 |
+
"class_frequencies": [
|
| 13 |
+
0.9936428560738232,
|
| 14 |
+
0.004258621959132971,
|
| 15 |
+
0.0020985219670438216
|
| 16 |
+
],
|
| 17 |
+
"class_weights": [
|
| 18 |
+
0.004240031541573218,
|
| 19 |
+
0.9890739908996539,
|
| 20 |
+
2.006685977558773
|
| 21 |
+
],
|
| 22 |
+
"class_names": [
|
| 23 |
+
"Background",
|
| 24 |
+
"Ventricles",
|
| 25 |
+
"Abnormal WMH"
|
| 26 |
+
]
|
| 27 |
+
}
|
models/for_WMH_Vent/data_splits/concat_fold_assignments.json
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"datasets": [
|
| 4 |
+
"Local_SAI",
|
| 5 |
+
"Public_MSSEG"
|
| 6 |
+
],
|
| 7 |
+
"total_patients": 115,
|
| 8 |
+
"test_patients": 13,
|
| 9 |
+
"trainval_patients": 102,
|
| 10 |
+
"local_split": "70/10/20",
|
| 11 |
+
"public_split": "60/20/20",
|
| 12 |
+
"n_folds": 4,
|
| 13 |
+
"random_seed": 42
|
| 14 |
+
},
|
| 15 |
+
"test_set": {
|
| 16 |
+
"patients": [
|
| 17 |
+
"110012",
|
| 18 |
+
"105549",
|
| 19 |
+
"109816",
|
| 20 |
+
"105074",
|
| 21 |
+
"106780",
|
| 22 |
+
"107680",
|
| 23 |
+
"108807",
|
| 24 |
+
"106063",
|
| 25 |
+
"114585",
|
| 26 |
+
"111489",
|
| 27 |
+
"c01p04",
|
| 28 |
+
"c07p05",
|
| 29 |
+
"c08p04"
|
| 30 |
+
],
|
| 31 |
+
"n_patients": 13
|
| 32 |
+
},
|
| 33 |
+
"folds": {
|
| 34 |
+
"fold_0": {
|
| 35 |
+
"train_patients": [
|
| 36 |
+
"109395",
|
| 37 |
+
"115788",
|
| 38 |
+
"113845",
|
| 39 |
+
"114770",
|
| 40 |
+
"102313",
|
| 41 |
+
"104797",
|
| 42 |
+
"111189",
|
| 43 |
+
"105597",
|
| 44 |
+
"111140",
|
| 45 |
+
"106270",
|
| 46 |
+
"114836",
|
| 47 |
+
"108295",
|
| 48 |
+
"104518",
|
| 49 |
+
"110218",
|
| 50 |
+
"110784",
|
| 51 |
+
"101627",
|
| 52 |
+
"104280",
|
| 53 |
+
"107966",
|
| 54 |
+
"101228",
|
| 55 |
+
"104420",
|
| 56 |
+
"109944",
|
| 57 |
+
"114903",
|
| 58 |
+
"112765",
|
| 59 |
+
"106200",
|
| 60 |
+
"106506",
|
| 61 |
+
"106536",
|
| 62 |
+
"112055",
|
| 63 |
+
"104447",
|
| 64 |
+
"106976",
|
| 65 |
+
"105978",
|
| 66 |
+
"110543",
|
| 67 |
+
"114058",
|
| 68 |
+
"113394",
|
| 69 |
+
"107739",
|
| 70 |
+
"112657",
|
| 71 |
+
"111008",
|
| 72 |
+
"105911",
|
| 73 |
+
"111852",
|
| 74 |
+
"105465",
|
| 75 |
+
"114128",
|
| 76 |
+
"110280",
|
| 77 |
+
"112414",
|
| 78 |
+
"105302",
|
| 79 |
+
"107455",
|
| 80 |
+
"110327",
|
| 81 |
+
"114990",
|
| 82 |
+
"112730",
|
| 83 |
+
"104453",
|
| 84 |
+
"111691",
|
| 85 |
+
"114454",
|
| 86 |
+
"104474",
|
| 87 |
+
"104252",
|
| 88 |
+
"109654",
|
| 89 |
+
"104937",
|
| 90 |
+
"104871",
|
| 91 |
+
"107508",
|
| 92 |
+
"114525",
|
| 93 |
+
"115588",
|
| 94 |
+
"110540",
|
| 95 |
+
"109267",
|
| 96 |
+
"107539",
|
| 97 |
+
"108344",
|
| 98 |
+
"112659",
|
| 99 |
+
"112776",
|
| 100 |
+
"113046",
|
| 101 |
+
"107233",
|
| 102 |
+
"102035",
|
| 103 |
+
"106905",
|
| 104 |
+
"107997",
|
| 105 |
+
"112378",
|
| 106 |
+
"104520",
|
| 107 |
+
"106639",
|
| 108 |
+
"104670",
|
| 109 |
+
"104899",
|
| 110 |
+
"115628",
|
| 111 |
+
"108444",
|
| 112 |
+
"109923",
|
| 113 |
+
"110157",
|
| 114 |
+
"114304",
|
| 115 |
+
"114266",
|
| 116 |
+
"c08p03",
|
| 117 |
+
"c01p01",
|
| 118 |
+
"c08p02",
|
| 119 |
+
"c07p03",
|
| 120 |
+
"c07p04",
|
| 121 |
+
"c01p02",
|
| 122 |
+
"c07p01",
|
| 123 |
+
"c08p05",
|
| 124 |
+
"c07p02"
|
| 125 |
+
],
|
| 126 |
+
"val_patients": [
|
| 127 |
+
"108726",
|
| 128 |
+
"105917",
|
| 129 |
+
"105755",
|
| 130 |
+
"109141",
|
| 131 |
+
"110497",
|
| 132 |
+
"112997",
|
| 133 |
+
"104810",
|
| 134 |
+
"108975",
|
| 135 |
+
"107130",
|
| 136 |
+
"107630",
|
| 137 |
+
"c01p05",
|
| 138 |
+
"c08p01",
|
| 139 |
+
"c01p03"
|
| 140 |
+
],
|
| 141 |
+
"n_train": 89,
|
| 142 |
+
"n_val": 13
|
| 143 |
+
},
|
| 144 |
+
"fold_1": {
|
| 145 |
+
"train_patients": [
|
| 146 |
+
"108726",
|
| 147 |
+
"105917",
|
| 148 |
+
"105755",
|
| 149 |
+
"109141",
|
| 150 |
+
"110497",
|
| 151 |
+
"112997",
|
| 152 |
+
"104810",
|
| 153 |
+
"108975",
|
| 154 |
+
"107130",
|
| 155 |
+
"107630",
|
| 156 |
+
"114836",
|
| 157 |
+
"108295",
|
| 158 |
+
"104518",
|
| 159 |
+
"110218",
|
| 160 |
+
"110784",
|
| 161 |
+
"101627",
|
| 162 |
+
"104280",
|
| 163 |
+
"107966",
|
| 164 |
+
"101228",
|
| 165 |
+
"104420",
|
| 166 |
+
"109944",
|
| 167 |
+
"114903",
|
| 168 |
+
"112765",
|
| 169 |
+
"106200",
|
| 170 |
+
"106506",
|
| 171 |
+
"106536",
|
| 172 |
+
"112055",
|
| 173 |
+
"104447",
|
| 174 |
+
"106976",
|
| 175 |
+
"105978",
|
| 176 |
+
"110543",
|
| 177 |
+
"114058",
|
| 178 |
+
"113394",
|
| 179 |
+
"107739",
|
| 180 |
+
"112657",
|
| 181 |
+
"111008",
|
| 182 |
+
"105911",
|
| 183 |
+
"111852",
|
| 184 |
+
"105465",
|
| 185 |
+
"114128",
|
| 186 |
+
"110280",
|
| 187 |
+
"112414",
|
| 188 |
+
"105302",
|
| 189 |
+
"107455",
|
| 190 |
+
"110327",
|
| 191 |
+
"114990",
|
| 192 |
+
"112730",
|
| 193 |
+
"104453",
|
| 194 |
+
"111691",
|
| 195 |
+
"114454",
|
| 196 |
+
"104474",
|
| 197 |
+
"104252",
|
| 198 |
+
"109654",
|
| 199 |
+
"104937",
|
| 200 |
+
"104871",
|
| 201 |
+
"107508",
|
| 202 |
+
"114525",
|
| 203 |
+
"115588",
|
| 204 |
+
"110540",
|
| 205 |
+
"109267",
|
| 206 |
+
"107539",
|
| 207 |
+
"108344",
|
| 208 |
+
"112659",
|
| 209 |
+
"112776",
|
| 210 |
+
"113046",
|
| 211 |
+
"107233",
|
| 212 |
+
"102035",
|
| 213 |
+
"106905",
|
| 214 |
+
"107997",
|
| 215 |
+
"112378",
|
| 216 |
+
"104520",
|
| 217 |
+
"106639",
|
| 218 |
+
"104670",
|
| 219 |
+
"104899",
|
| 220 |
+
"115628",
|
| 221 |
+
"108444",
|
| 222 |
+
"109923",
|
| 223 |
+
"110157",
|
| 224 |
+
"114304",
|
| 225 |
+
"114266",
|
| 226 |
+
"c01p05",
|
| 227 |
+
"c08p01",
|
| 228 |
+
"c01p03",
|
| 229 |
+
"c07p03",
|
| 230 |
+
"c07p04",
|
| 231 |
+
"c01p02",
|
| 232 |
+
"c07p01",
|
| 233 |
+
"c08p05",
|
| 234 |
+
"c07p02"
|
| 235 |
+
],
|
| 236 |
+
"val_patients": [
|
| 237 |
+
"109395",
|
| 238 |
+
"115788",
|
| 239 |
+
"113845",
|
| 240 |
+
"114770",
|
| 241 |
+
"102313",
|
| 242 |
+
"104797",
|
| 243 |
+
"111189",
|
| 244 |
+
"105597",
|
| 245 |
+
"111140",
|
| 246 |
+
"106270",
|
| 247 |
+
"c08p03",
|
| 248 |
+
"c01p01",
|
| 249 |
+
"c08p02"
|
| 250 |
+
],
|
| 251 |
+
"n_train": 89,
|
| 252 |
+
"n_val": 13
|
| 253 |
+
},
|
| 254 |
+
"fold_2": {
|
| 255 |
+
"train_patients": [
|
| 256 |
+
"108726",
|
| 257 |
+
"105917",
|
| 258 |
+
"105755",
|
| 259 |
+
"109141",
|
| 260 |
+
"110497",
|
| 261 |
+
"112997",
|
| 262 |
+
"104810",
|
| 263 |
+
"108975",
|
| 264 |
+
"107130",
|
| 265 |
+
"107630",
|
| 266 |
+
"109395",
|
| 267 |
+
"115788",
|
| 268 |
+
"113845",
|
| 269 |
+
"114770",
|
| 270 |
+
"102313",
|
| 271 |
+
"104797",
|
| 272 |
+
"111189",
|
| 273 |
+
"105597",
|
| 274 |
+
"111140",
|
| 275 |
+
"106270",
|
| 276 |
+
"109944",
|
| 277 |
+
"114903",
|
| 278 |
+
"112765",
|
| 279 |
+
"106200",
|
| 280 |
+
"106506",
|
| 281 |
+
"106536",
|
| 282 |
+
"112055",
|
| 283 |
+
"104447",
|
| 284 |
+
"106976",
|
| 285 |
+
"105978",
|
| 286 |
+
"110543",
|
| 287 |
+
"114058",
|
| 288 |
+
"113394",
|
| 289 |
+
"107739",
|
| 290 |
+
"112657",
|
| 291 |
+
"111008",
|
| 292 |
+
"105911",
|
| 293 |
+
"111852",
|
| 294 |
+
"105465",
|
| 295 |
+
"114128",
|
| 296 |
+
"110280",
|
| 297 |
+
"112414",
|
| 298 |
+
"105302",
|
| 299 |
+
"107455",
|
| 300 |
+
"110327",
|
| 301 |
+
"114990",
|
| 302 |
+
"112730",
|
| 303 |
+
"104453",
|
| 304 |
+
"111691",
|
| 305 |
+
"114454",
|
| 306 |
+
"104474",
|
| 307 |
+
"104252",
|
| 308 |
+
"109654",
|
| 309 |
+
"104937",
|
| 310 |
+
"104871",
|
| 311 |
+
"107508",
|
| 312 |
+
"114525",
|
| 313 |
+
"115588",
|
| 314 |
+
"110540",
|
| 315 |
+
"109267",
|
| 316 |
+
"107539",
|
| 317 |
+
"108344",
|
| 318 |
+
"112659",
|
| 319 |
+
"112776",
|
| 320 |
+
"113046",
|
| 321 |
+
"107233",
|
| 322 |
+
"102035",
|
| 323 |
+
"106905",
|
| 324 |
+
"107997",
|
| 325 |
+
"112378",
|
| 326 |
+
"104520",
|
| 327 |
+
"106639",
|
| 328 |
+
"104670",
|
| 329 |
+
"104899",
|
| 330 |
+
"115628",
|
| 331 |
+
"108444",
|
| 332 |
+
"109923",
|
| 333 |
+
"110157",
|
| 334 |
+
"114304",
|
| 335 |
+
"114266",
|
| 336 |
+
"c01p05",
|
| 337 |
+
"c08p01",
|
| 338 |
+
"c01p03",
|
| 339 |
+
"c08p03",
|
| 340 |
+
"c01p01",
|
| 341 |
+
"c08p02",
|
| 342 |
+
"c07p01",
|
| 343 |
+
"c08p05",
|
| 344 |
+
"c07p02"
|
| 345 |
+
],
|
| 346 |
+
"val_patients": [
|
| 347 |
+
"114836",
|
| 348 |
+
"108295",
|
| 349 |
+
"104518",
|
| 350 |
+
"110218",
|
| 351 |
+
"110784",
|
| 352 |
+
"101627",
|
| 353 |
+
"104280",
|
| 354 |
+
"107966",
|
| 355 |
+
"101228",
|
| 356 |
+
"104420",
|
| 357 |
+
"c07p03",
|
| 358 |
+
"c07p04",
|
| 359 |
+
"c01p02"
|
| 360 |
+
],
|
| 361 |
+
"n_train": 89,
|
| 362 |
+
"n_val": 13
|
| 363 |
+
},
|
| 364 |
+
"fold_3": {
|
| 365 |
+
"train_patients": [
|
| 366 |
+
"108726",
|
| 367 |
+
"105917",
|
| 368 |
+
"105755",
|
| 369 |
+
"109141",
|
| 370 |
+
"110497",
|
| 371 |
+
"112997",
|
| 372 |
+
"104810",
|
| 373 |
+
"108975",
|
| 374 |
+
"107130",
|
| 375 |
+
"107630",
|
| 376 |
+
"109395",
|
| 377 |
+
"115788",
|
| 378 |
+
"113845",
|
| 379 |
+
"114770",
|
| 380 |
+
"102313",
|
| 381 |
+
"104797",
|
| 382 |
+
"111189",
|
| 383 |
+
"105597",
|
| 384 |
+
"111140",
|
| 385 |
+
"106270",
|
| 386 |
+
"114836",
|
| 387 |
+
"108295",
|
| 388 |
+
"104518",
|
| 389 |
+
"110218",
|
| 390 |
+
"110784",
|
| 391 |
+
"101627",
|
| 392 |
+
"104280",
|
| 393 |
+
"107966",
|
| 394 |
+
"101228",
|
| 395 |
+
"104420",
|
| 396 |
+
"110543",
|
| 397 |
+
"114058",
|
| 398 |
+
"113394",
|
| 399 |
+
"107739",
|
| 400 |
+
"112657",
|
| 401 |
+
"111008",
|
| 402 |
+
"105911",
|
| 403 |
+
"111852",
|
| 404 |
+
"105465",
|
| 405 |
+
"114128",
|
| 406 |
+
"110280",
|
| 407 |
+
"112414",
|
| 408 |
+
"105302",
|
| 409 |
+
"107455",
|
| 410 |
+
"110327",
|
| 411 |
+
"114990",
|
| 412 |
+
"112730",
|
| 413 |
+
"104453",
|
| 414 |
+
"111691",
|
| 415 |
+
"114454",
|
| 416 |
+
"104474",
|
| 417 |
+
"104252",
|
| 418 |
+
"109654",
|
| 419 |
+
"104937",
|
| 420 |
+
"104871",
|
| 421 |
+
"107508",
|
| 422 |
+
"114525",
|
| 423 |
+
"115588",
|
| 424 |
+
"110540",
|
| 425 |
+
"109267",
|
| 426 |
+
"107539",
|
| 427 |
+
"108344",
|
| 428 |
+
"112659",
|
| 429 |
+
"112776",
|
| 430 |
+
"113046",
|
| 431 |
+
"107233",
|
| 432 |
+
"102035",
|
| 433 |
+
"106905",
|
| 434 |
+
"107997",
|
| 435 |
+
"112378",
|
| 436 |
+
"104520",
|
| 437 |
+
"106639",
|
| 438 |
+
"104670",
|
| 439 |
+
"104899",
|
| 440 |
+
"115628",
|
| 441 |
+
"108444",
|
| 442 |
+
"109923",
|
| 443 |
+
"110157",
|
| 444 |
+
"114304",
|
| 445 |
+
"114266",
|
| 446 |
+
"c01p05",
|
| 447 |
+
"c08p01",
|
| 448 |
+
"c01p03",
|
| 449 |
+
"c08p03",
|
| 450 |
+
"c01p01",
|
| 451 |
+
"c08p02",
|
| 452 |
+
"c07p03",
|
| 453 |
+
"c07p04",
|
| 454 |
+
"c01p02"
|
| 455 |
+
],
|
| 456 |
+
"val_patients": [
|
| 457 |
+
"109944",
|
| 458 |
+
"114903",
|
| 459 |
+
"112765",
|
| 460 |
+
"106200",
|
| 461 |
+
"106506",
|
| 462 |
+
"106536",
|
| 463 |
+
"112055",
|
| 464 |
+
"104447",
|
| 465 |
+
"106976",
|
| 466 |
+
"105978",
|
| 467 |
+
"c07p01",
|
| 468 |
+
"c08p05",
|
| 469 |
+
"c07p02"
|
| 470 |
+
],
|
| 471 |
+
"n_train": 89,
|
| 472 |
+
"n_val": 13
|
| 473 |
+
}
|
| 474 |
+
}
|
| 475 |
+
}
|
models/for_WMH_Vent/data_splits/fold_assignments.json
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_patients": 115,
|
| 4 |
+
"test_patients": 23,
|
| 5 |
+
"trainval_patients": 92,
|
| 6 |
+
"n_folds": 5,
|
| 7 |
+
"random_seed": 42,
|
| 8 |
+
"datasets": [
|
| 9 |
+
"Local_SAI",
|
| 10 |
+
"Public_MSSEG"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
"test_set": {
|
| 14 |
+
"patients": [
|
| 15 |
+
"112776",
|
| 16 |
+
"104252",
|
| 17 |
+
"107539",
|
| 18 |
+
"111140",
|
| 19 |
+
"104518",
|
| 20 |
+
"107997",
|
| 21 |
+
"111189",
|
| 22 |
+
"110543",
|
| 23 |
+
"108344",
|
| 24 |
+
"104520",
|
| 25 |
+
"c01p01",
|
| 26 |
+
"107130",
|
| 27 |
+
"113394",
|
| 28 |
+
"c08p04",
|
| 29 |
+
"105074",
|
| 30 |
+
"101228",
|
| 31 |
+
"111691",
|
| 32 |
+
"105978",
|
| 33 |
+
"c07p01",
|
| 34 |
+
"109267",
|
| 35 |
+
"114836",
|
| 36 |
+
"c08p03",
|
| 37 |
+
"104670"
|
| 38 |
+
],
|
| 39 |
+
"n_patients": 23
|
| 40 |
+
},
|
| 41 |
+
"folds": {
|
| 42 |
+
"fold_0": {
|
| 43 |
+
"train_patients": [
|
| 44 |
+
"102035",
|
| 45 |
+
"102313",
|
| 46 |
+
"104280",
|
| 47 |
+
"104447",
|
| 48 |
+
"104453",
|
| 49 |
+
"104474",
|
| 50 |
+
"104797",
|
| 51 |
+
"104810",
|
| 52 |
+
"104899",
|
| 53 |
+
"105302",
|
| 54 |
+
"105465",
|
| 55 |
+
"105549",
|
| 56 |
+
"105597",
|
| 57 |
+
"105755",
|
| 58 |
+
"105917",
|
| 59 |
+
"106063",
|
| 60 |
+
"106200",
|
| 61 |
+
"106506",
|
| 62 |
+
"106536",
|
| 63 |
+
"106639",
|
| 64 |
+
"106905",
|
| 65 |
+
"107233",
|
| 66 |
+
"107455",
|
| 67 |
+
"107508",
|
| 68 |
+
"107630",
|
| 69 |
+
"107680",
|
| 70 |
+
"107739",
|
| 71 |
+
"108295",
|
| 72 |
+
"108444",
|
| 73 |
+
"108726",
|
| 74 |
+
"109141",
|
| 75 |
+
"109395",
|
| 76 |
+
"109654",
|
| 77 |
+
"109923",
|
| 78 |
+
"109944",
|
| 79 |
+
"110012",
|
| 80 |
+
"110157",
|
| 81 |
+
"110280",
|
| 82 |
+
"110327",
|
| 83 |
+
"110497",
|
| 84 |
+
"110540",
|
| 85 |
+
"110784",
|
| 86 |
+
"111489",
|
| 87 |
+
"111852",
|
| 88 |
+
"112055",
|
| 89 |
+
"112378",
|
| 90 |
+
"112414",
|
| 91 |
+
"112657",
|
| 92 |
+
"112730",
|
| 93 |
+
"112765",
|
| 94 |
+
"112997",
|
| 95 |
+
"113046",
|
| 96 |
+
"114058",
|
| 97 |
+
"114128",
|
| 98 |
+
"114266",
|
| 99 |
+
"114304",
|
| 100 |
+
"114525",
|
| 101 |
+
"114585",
|
| 102 |
+
"114770",
|
| 103 |
+
"114903",
|
| 104 |
+
"114990",
|
| 105 |
+
"115588",
|
| 106 |
+
"115628",
|
| 107 |
+
"115788",
|
| 108 |
+
"c01p02",
|
| 109 |
+
"c01p03",
|
| 110 |
+
"c01p05",
|
| 111 |
+
"c07p02",
|
| 112 |
+
"c07p03",
|
| 113 |
+
"c07p04",
|
| 114 |
+
"c07p05",
|
| 115 |
+
"c08p02",
|
| 116 |
+
"c08p05"
|
| 117 |
+
],
|
| 118 |
+
"val_patients": [
|
| 119 |
+
"101627",
|
| 120 |
+
"104420",
|
| 121 |
+
"104871",
|
| 122 |
+
"104937",
|
| 123 |
+
"105911",
|
| 124 |
+
"106270",
|
| 125 |
+
"106780",
|
| 126 |
+
"106976",
|
| 127 |
+
"107966",
|
| 128 |
+
"108807",
|
| 129 |
+
"108975",
|
| 130 |
+
"109816",
|
| 131 |
+
"110218",
|
| 132 |
+
"111008",
|
| 133 |
+
"112659",
|
| 134 |
+
"113845",
|
| 135 |
+
"114454",
|
| 136 |
+
"c01p04",
|
| 137 |
+
"c08p01"
|
| 138 |
+
],
|
| 139 |
+
"n_train": 73,
|
| 140 |
+
"n_val": 19
|
| 141 |
+
},
|
| 142 |
+
"fold_1": {
|
| 143 |
+
"train_patients": [
|
| 144 |
+
"101627",
|
| 145 |
+
"102035",
|
| 146 |
+
"102313",
|
| 147 |
+
"104280",
|
| 148 |
+
"104420",
|
| 149 |
+
"104453",
|
| 150 |
+
"104474",
|
| 151 |
+
"104797",
|
| 152 |
+
"104871",
|
| 153 |
+
"104937",
|
| 154 |
+
"105302",
|
| 155 |
+
"105465",
|
| 156 |
+
"105755",
|
| 157 |
+
"105911",
|
| 158 |
+
"105917",
|
| 159 |
+
"106063",
|
| 160 |
+
"106200",
|
| 161 |
+
"106270",
|
| 162 |
+
"106506",
|
| 163 |
+
"106536",
|
| 164 |
+
"106639",
|
| 165 |
+
"106780",
|
| 166 |
+
"106905",
|
| 167 |
+
"106976",
|
| 168 |
+
"107233",
|
| 169 |
+
"107630",
|
| 170 |
+
"107966",
|
| 171 |
+
"108295",
|
| 172 |
+
"108444",
|
| 173 |
+
"108726",
|
| 174 |
+
"108807",
|
| 175 |
+
"108975",
|
| 176 |
+
"109141",
|
| 177 |
+
"109654",
|
| 178 |
+
"109816",
|
| 179 |
+
"109944",
|
| 180 |
+
"110157",
|
| 181 |
+
"110218",
|
| 182 |
+
"110280",
|
| 183 |
+
"110327",
|
| 184 |
+
"110497",
|
| 185 |
+
"110540",
|
| 186 |
+
"110784",
|
| 187 |
+
"111008",
|
| 188 |
+
"111489",
|
| 189 |
+
"111852",
|
| 190 |
+
"112055",
|
| 191 |
+
"112378",
|
| 192 |
+
"112414",
|
| 193 |
+
"112657",
|
| 194 |
+
"112659",
|
| 195 |
+
"112730",
|
| 196 |
+
"112765",
|
| 197 |
+
"113845",
|
| 198 |
+
"114304",
|
| 199 |
+
"114454",
|
| 200 |
+
"114525",
|
| 201 |
+
"114585",
|
| 202 |
+
"114770",
|
| 203 |
+
"114903",
|
| 204 |
+
"115628",
|
| 205 |
+
"115788",
|
| 206 |
+
"c01p02",
|
| 207 |
+
"c01p03",
|
| 208 |
+
"c01p04",
|
| 209 |
+
"c01p05",
|
| 210 |
+
"c07p02",
|
| 211 |
+
"c07p03",
|
| 212 |
+
"c07p04",
|
| 213 |
+
"c07p05",
|
| 214 |
+
"c08p01",
|
| 215 |
+
"c08p02",
|
| 216 |
+
"c08p05"
|
| 217 |
+
],
|
| 218 |
+
"val_patients": [
|
| 219 |
+
"104447",
|
| 220 |
+
"104810",
|
| 221 |
+
"104899",
|
| 222 |
+
"105549",
|
| 223 |
+
"105597",
|
| 224 |
+
"107455",
|
| 225 |
+
"107508",
|
| 226 |
+
"107680",
|
| 227 |
+
"107739",
|
| 228 |
+
"109395",
|
| 229 |
+
"109923",
|
| 230 |
+
"110012",
|
| 231 |
+
"112997",
|
| 232 |
+
"113046",
|
| 233 |
+
"114058",
|
| 234 |
+
"114128",
|
| 235 |
+
"114266",
|
| 236 |
+
"114990",
|
| 237 |
+
"115588"
|
| 238 |
+
],
|
| 239 |
+
"n_train": 73,
|
| 240 |
+
"n_val": 19
|
| 241 |
+
},
|
| 242 |
+
"fold_2": {
|
| 243 |
+
"train_patients": [
|
| 244 |
+
"101627",
|
| 245 |
+
"102035",
|
| 246 |
+
"102313",
|
| 247 |
+
"104420",
|
| 248 |
+
"104447",
|
| 249 |
+
"104810",
|
| 250 |
+
"104871",
|
| 251 |
+
"104899",
|
| 252 |
+
"104937",
|
| 253 |
+
"105465",
|
| 254 |
+
"105549",
|
| 255 |
+
"105597",
|
| 256 |
+
"105911",
|
| 257 |
+
"106063",
|
| 258 |
+
"106200",
|
| 259 |
+
"106270",
|
| 260 |
+
"106506",
|
| 261 |
+
"106780",
|
| 262 |
+
"106976",
|
| 263 |
+
"107233",
|
| 264 |
+
"107455",
|
| 265 |
+
"107508",
|
| 266 |
+
"107630",
|
| 267 |
+
"107680",
|
| 268 |
+
"107739",
|
| 269 |
+
"107966",
|
| 270 |
+
"108444",
|
| 271 |
+
"108807",
|
| 272 |
+
"108975",
|
| 273 |
+
"109141",
|
| 274 |
+
"109395",
|
| 275 |
+
"109654",
|
| 276 |
+
"109816",
|
| 277 |
+
"109923",
|
| 278 |
+
"109944",
|
| 279 |
+
"110012",
|
| 280 |
+
"110157",
|
| 281 |
+
"110218",
|
| 282 |
+
"110280",
|
| 283 |
+
"110327",
|
| 284 |
+
"110497",
|
| 285 |
+
"110784",
|
| 286 |
+
"111008",
|
| 287 |
+
"111489",
|
| 288 |
+
"111852",
|
| 289 |
+
"112055",
|
| 290 |
+
"112378",
|
| 291 |
+
"112414",
|
| 292 |
+
"112657",
|
| 293 |
+
"112659",
|
| 294 |
+
"112730",
|
| 295 |
+
"112765",
|
| 296 |
+
"112997",
|
| 297 |
+
"113046",
|
| 298 |
+
"113845",
|
| 299 |
+
"114058",
|
| 300 |
+
"114128",
|
| 301 |
+
"114266",
|
| 302 |
+
"114304",
|
| 303 |
+
"114454",
|
| 304 |
+
"114585",
|
| 305 |
+
"114770",
|
| 306 |
+
"114990",
|
| 307 |
+
"115588",
|
| 308 |
+
"115628",
|
| 309 |
+
"c01p02",
|
| 310 |
+
"c01p03",
|
| 311 |
+
"c01p04",
|
| 312 |
+
"c01p05",
|
| 313 |
+
"c07p02",
|
| 314 |
+
"c07p04",
|
| 315 |
+
"c08p01",
|
| 316 |
+
"c08p02",
|
| 317 |
+
"c08p05"
|
| 318 |
+
],
|
| 319 |
+
"val_patients": [
|
| 320 |
+
"104280",
|
| 321 |
+
"104453",
|
| 322 |
+
"104474",
|
| 323 |
+
"104797",
|
| 324 |
+
"105302",
|
| 325 |
+
"105755",
|
| 326 |
+
"105917",
|
| 327 |
+
"106536",
|
| 328 |
+
"106639",
|
| 329 |
+
"106905",
|
| 330 |
+
"108295",
|
| 331 |
+
"108726",
|
| 332 |
+
"110540",
|
| 333 |
+
"114525",
|
| 334 |
+
"114903",
|
| 335 |
+
"115788",
|
| 336 |
+
"c07p03",
|
| 337 |
+
"c07p05"
|
| 338 |
+
],
|
| 339 |
+
"n_train": 74,
|
| 340 |
+
"n_val": 18
|
| 341 |
+
},
|
| 342 |
+
"fold_3": {
|
| 343 |
+
"train_patients": [
|
| 344 |
+
"101627",
|
| 345 |
+
"102035",
|
| 346 |
+
"102313",
|
| 347 |
+
"104280",
|
| 348 |
+
"104420",
|
| 349 |
+
"104447",
|
| 350 |
+
"104453",
|
| 351 |
+
"104474",
|
| 352 |
+
"104797",
|
| 353 |
+
"104810",
|
| 354 |
+
"104871",
|
| 355 |
+
"104899",
|
| 356 |
+
"104937",
|
| 357 |
+
"105302",
|
| 358 |
+
"105465",
|
| 359 |
+
"105549",
|
| 360 |
+
"105597",
|
| 361 |
+
"105755",
|
| 362 |
+
"105911",
|
| 363 |
+
"105917",
|
| 364 |
+
"106063",
|
| 365 |
+
"106200",
|
| 366 |
+
"106270",
|
| 367 |
+
"106506",
|
| 368 |
+
"106536",
|
| 369 |
+
"106639",
|
| 370 |
+
"106780",
|
| 371 |
+
"106905",
|
| 372 |
+
"106976",
|
| 373 |
+
"107233",
|
| 374 |
+
"107455",
|
| 375 |
+
"107508",
|
| 376 |
+
"107680",
|
| 377 |
+
"107739",
|
| 378 |
+
"107966",
|
| 379 |
+
"108295",
|
| 380 |
+
"108444",
|
| 381 |
+
"108726",
|
| 382 |
+
"108807",
|
| 383 |
+
"108975",
|
| 384 |
+
"109395",
|
| 385 |
+
"109816",
|
| 386 |
+
"109923",
|
| 387 |
+
"110012",
|
| 388 |
+
"110218",
|
| 389 |
+
"110327",
|
| 390 |
+
"110497",
|
| 391 |
+
"110540",
|
| 392 |
+
"111008",
|
| 393 |
+
"112378",
|
| 394 |
+
"112414",
|
| 395 |
+
"112659",
|
| 396 |
+
"112730",
|
| 397 |
+
"112997",
|
| 398 |
+
"113046",
|
| 399 |
+
"113845",
|
| 400 |
+
"114058",
|
| 401 |
+
"114128",
|
| 402 |
+
"114266",
|
| 403 |
+
"114304",
|
| 404 |
+
"114454",
|
| 405 |
+
"114525",
|
| 406 |
+
"114585",
|
| 407 |
+
"114903",
|
| 408 |
+
"114990",
|
| 409 |
+
"115588",
|
| 410 |
+
"115628",
|
| 411 |
+
"115788",
|
| 412 |
+
"c01p03",
|
| 413 |
+
"c01p04",
|
| 414 |
+
"c07p02",
|
| 415 |
+
"c07p03",
|
| 416 |
+
"c07p05",
|
| 417 |
+
"c08p01"
|
| 418 |
+
],
|
| 419 |
+
"val_patients": [
|
| 420 |
+
"107630",
|
| 421 |
+
"109141",
|
| 422 |
+
"109654",
|
| 423 |
+
"109944",
|
| 424 |
+
"110157",
|
| 425 |
+
"110280",
|
| 426 |
+
"110784",
|
| 427 |
+
"111489",
|
| 428 |
+
"111852",
|
| 429 |
+
"112055",
|
| 430 |
+
"112657",
|
| 431 |
+
"112765",
|
| 432 |
+
"114770",
|
| 433 |
+
"c01p02",
|
| 434 |
+
"c01p05",
|
| 435 |
+
"c07p04",
|
| 436 |
+
"c08p02",
|
| 437 |
+
"c08p05"
|
| 438 |
+
],
|
| 439 |
+
"n_train": 74,
|
| 440 |
+
"n_val": 18
|
| 441 |
+
},
|
| 442 |
+
"fold_4": {
|
| 443 |
+
"train_patients": [
|
| 444 |
+
"101627",
|
| 445 |
+
"104280",
|
| 446 |
+
"104420",
|
| 447 |
+
"104447",
|
| 448 |
+
"104453",
|
| 449 |
+
"104474",
|
| 450 |
+
"104797",
|
| 451 |
+
"104810",
|
| 452 |
+
"104871",
|
| 453 |
+
"104899",
|
| 454 |
+
"104937",
|
| 455 |
+
"105302",
|
| 456 |
+
"105549",
|
| 457 |
+
"105597",
|
| 458 |
+
"105755",
|
| 459 |
+
"105911",
|
| 460 |
+
"105917",
|
| 461 |
+
"106270",
|
| 462 |
+
"106536",
|
| 463 |
+
"106639",
|
| 464 |
+
"106780",
|
| 465 |
+
"106905",
|
| 466 |
+
"106976",
|
| 467 |
+
"107455",
|
| 468 |
+
"107508",
|
| 469 |
+
"107630",
|
| 470 |
+
"107680",
|
| 471 |
+
"107739",
|
| 472 |
+
"107966",
|
| 473 |
+
"108295",
|
| 474 |
+
"108726",
|
| 475 |
+
"108807",
|
| 476 |
+
"108975",
|
| 477 |
+
"109141",
|
| 478 |
+
"109395",
|
| 479 |
+
"109654",
|
| 480 |
+
"109816",
|
| 481 |
+
"109923",
|
| 482 |
+
"109944",
|
| 483 |
+
"110012",
|
| 484 |
+
"110157",
|
| 485 |
+
"110218",
|
| 486 |
+
"110280",
|
| 487 |
+
"110540",
|
| 488 |
+
"110784",
|
| 489 |
+
"111008",
|
| 490 |
+
"111489",
|
| 491 |
+
"111852",
|
| 492 |
+
"112055",
|
| 493 |
+
"112657",
|
| 494 |
+
"112659",
|
| 495 |
+
"112765",
|
| 496 |
+
"112997",
|
| 497 |
+
"113046",
|
| 498 |
+
"113845",
|
| 499 |
+
"114058",
|
| 500 |
+
"114128",
|
| 501 |
+
"114266",
|
| 502 |
+
"114454",
|
| 503 |
+
"114525",
|
| 504 |
+
"114770",
|
| 505 |
+
"114903",
|
| 506 |
+
"114990",
|
| 507 |
+
"115588",
|
| 508 |
+
"115788",
|
| 509 |
+
"c01p02",
|
| 510 |
+
"c01p04",
|
| 511 |
+
"c01p05",
|
| 512 |
+
"c07p03",
|
| 513 |
+
"c07p04",
|
| 514 |
+
"c07p05",
|
| 515 |
+
"c08p01",
|
| 516 |
+
"c08p02",
|
| 517 |
+
"c08p05"
|
| 518 |
+
],
|
| 519 |
+
"val_patients": [
|
| 520 |
+
"102035",
|
| 521 |
+
"102313",
|
| 522 |
+
"105465",
|
| 523 |
+
"106063",
|
| 524 |
+
"106200",
|
| 525 |
+
"106506",
|
| 526 |
+
"107233",
|
| 527 |
+
"108444",
|
| 528 |
+
"110327",
|
| 529 |
+
"110497",
|
| 530 |
+
"112378",
|
| 531 |
+
"112414",
|
| 532 |
+
"112730",
|
| 533 |
+
"114304",
|
| 534 |
+
"114585",
|
| 535 |
+
"115628",
|
| 536 |
+
"c01p03",
|
| 537 |
+
"c07p02"
|
| 538 |
+
],
|
| 539 |
+
"n_train": 74,
|
| 540 |
+
"n_val": 18
|
| 541 |
+
}
|
| 542 |
+
}
|
| 543 |
+
}
|
models/for_WMH_Vent/data_splits/for_assignment.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from sklearn.model_selection import KFold
|
| 5 |
+
|
| 6 |
+
# ─────────────────────────────────────────────
|
| 7 |
+
# Patient IDs
|
| 8 |
+
# ─────────────────────────────────────────────
|
| 9 |
+
local_patients_id = [
|
| 10 |
+
'101228', '101627', '102035', '102313', '104252', '104280', '104420',
|
| 11 |
+
'104447', '104453', '104474', '104518', '104520', '104670', '104797',
|
| 12 |
+
'104810', '104871', '104899', '104937', '105074', '105302', '105465',
|
| 13 |
+
'105549', '105597', '105755', '105911', '105917', '105978', '106063',
|
| 14 |
+
'106200', '106270', '106506', '106536', '106639', '106780', '106905',
|
| 15 |
+
'106976', '107130', '107233', '107455', '107508', '107539', '107630',
|
| 16 |
+
'107680', '107739', '107966', '107997', '108295', '108344', '108444',
|
| 17 |
+
'108726', '108807', '108975', '109141', '109267', '109395', '109654',
|
| 18 |
+
'109816', '109923', '109944', '110012', '110157', '110218', '110280',
|
| 19 |
+
'110327', '110497', '110540', '110543', '110784', '111008', '111140',
|
| 20 |
+
'111189', '111489', '111691', '111852', '112055', '112378', '112414',
|
| 21 |
+
'112657', '112659', '112730', '112765', '112776', '112997', '113046',
|
| 22 |
+
'113394', '113845', '114058', '114128', '114266', '114304', '114454',
|
| 23 |
+
'114525', '114585', '114770', '114836', '114903', '114990', '115588',
|
| 24 |
+
'115628', '115788',
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
public_patients_id = [
|
| 28 |
+
'c01p01', 'c01p02', 'c01p03', 'c01p04', 'c01p05',
|
| 29 |
+
'c07p01', 'c07p02', 'c07p03', 'c07p04', 'c07p05',
|
| 30 |
+
'c08p01', 'c08p02', 'c08p03', 'c08p04', 'c08p05',
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
RANDOM_SEED = 42
|
| 34 |
+
N_FOLDS = 4
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 38 |
+
# make_folds_exact (LOCAL)
|
| 39 |
+
# Carves n_val_per_fold * n_folds patients as an exclusive val pool,
|
| 40 |
+
# then rotates the val window. Val sets are perfectly non-overlapping.
|
| 41 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 42 |
+
def make_folds_exact(trainval, n_val_per_fold, n_folds, rng):
|
| 43 |
+
arr = np.array(trainval)
|
| 44 |
+
rng.shuffle(arr)
|
| 45 |
+
|
| 46 |
+
total_val_pool = n_folds * n_val_per_fold # 5 * 10 = 50
|
| 47 |
+
assert total_val_pool <= len(arr), (
|
| 48 |
+
f"Not enough trainval ({len(arr)}) for {n_folds} x {n_val_per_fold} val = {total_val_pool}"
|
| 49 |
+
)
|
| 50 |
+
val_pool = arr[:total_val_pool] # 50 dedicated val patients
|
| 51 |
+
train_base = arr[total_val_pool:] # 29 always-train patients
|
| 52 |
+
|
| 53 |
+
folds = {}
|
| 54 |
+
for fold_idx in range(n_folds):
|
| 55 |
+
val_pts = val_pool[fold_idx * n_val_per_fold:(fold_idx + 1) * n_val_per_fold].tolist()
|
| 56 |
+
other_val = np.concatenate([
|
| 57 |
+
val_pool[:fold_idx * n_val_per_fold],
|
| 58 |
+
val_pool[(fold_idx + 1) * n_val_per_fold:]
|
| 59 |
+
])
|
| 60 |
+
train_pts = np.concatenate([other_val, train_base]).tolist()
|
| 61 |
+
folds[f"fold_{fold_idx}"] = {
|
| 62 |
+
"train_patients": train_pts,
|
| 63 |
+
"val_patients": val_pts,
|
| 64 |
+
"n_train": len(train_pts),
|
| 65 |
+
"n_val": len(val_pts),
|
| 66 |
+
}
|
| 67 |
+
return folds
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 71 |
+
# make_folds_kfold (PUBLIC)
|
| 72 |
+
# With only 12 trainval patients and 5 folds, KFold is the only way to keep
|
| 73 |
+
# val sets strictly non-overlapping. Val sizes will be 3,3,2,2,2.
|
| 74 |
+
# (5 * 3 = 15 > 12, so exact 3 per fold is mathematically impossible without
|
| 75 |
+
# overlap; KFold is the standard, correct solution.)
|
| 76 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 77 |
+
def make_folds_kfold(trainval, n_folds, rng):
|
| 78 |
+
arr = np.array(trainval)
|
| 79 |
+
rng.shuffle(arr)
|
| 80 |
+
|
| 81 |
+
kf = KFold(n_splits=n_folds, shuffle=False) # arr already shuffled
|
| 82 |
+
folds = {}
|
| 83 |
+
for fold_idx, (train_idx, val_idx) in enumerate(kf.split(arr)):
|
| 84 |
+
folds[f"fold_{fold_idx}"] = {
|
| 85 |
+
"train_patients": arr[train_idx].tolist(),
|
| 86 |
+
"val_patients": arr[val_idx].tolist(),
|
| 87 |
+
"n_train": len(train_idx),
|
| 88 |
+
"n_val": len(val_idx),
|
| 89 |
+
}
|
| 90 |
+
return folds
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ─────────────────────────────────────────────────────────────────────────���───
|
| 94 |
+
# LOCAL -- 70 / 10 / 20
|
| 95 |
+
# 99 total -> test=20, val=10 per fold, train=69 per fold
|
| 96 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 97 |
+
n_local = len(local_patients_id) # 99
|
| 98 |
+
n_local_test = round(n_local * 0.20) # 20
|
| 99 |
+
n_local_val_per_fold = round(n_local * 0.10) # 10
|
| 100 |
+
|
| 101 |
+
rng_local = np.random.default_rng(RANDOM_SEED)
|
| 102 |
+
local_arr = np.array(local_patients_id)
|
| 103 |
+
rng_local.shuffle(local_arr)
|
| 104 |
+
|
| 105 |
+
local_test = local_arr[:n_local_test].tolist() # 20
|
| 106 |
+
local_trainval = local_arr[n_local_test:].tolist() # 79
|
| 107 |
+
|
| 108 |
+
local_folds = make_folds_exact(
|
| 109 |
+
local_trainval,
|
| 110 |
+
n_val_per_fold=n_local_val_per_fold,
|
| 111 |
+
n_folds=N_FOLDS,
|
| 112 |
+
rng=np.random.default_rng(RANDOM_SEED + 1),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
local_split = {
|
| 116 |
+
"metadata": {
|
| 117 |
+
"dataset": "Local_SAI",
|
| 118 |
+
"total_patients": n_local,
|
| 119 |
+
"test_patients": n_local_test,
|
| 120 |
+
"trainval_patients": len(local_trainval),
|
| 121 |
+
"target_split": "70/10/20 (train/val/test)",
|
| 122 |
+
"exact_counts": "train=69, val=10, test=20 per fold",
|
| 123 |
+
"n_folds": N_FOLDS,
|
| 124 |
+
"random_seed": RANDOM_SEED,
|
| 125 |
+
},
|
| 126 |
+
"test_set": {"patients": local_test, "n_patients": n_local_test},
|
| 127 |
+
"folds": local_folds,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 131 |
+
# PUBLIC -- 60 / 20 / 20
|
| 132 |
+
# 15 total -> test=3 (center-balanced), trainval=12
|
| 133 |
+
# KFold(5) on 12 -> val sizes: 3,3,2,2,2 (non-overlapping, closest to 20%)
|
| 134 |
+
# train sizes: 9,9,10,10,10
|
| 135 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 136 |
+
n_public = len(public_patients_id) # 15
|
| 137 |
+
|
| 138 |
+
# Center-balanced test: 1 patient per center
|
| 139 |
+
centers = {}
|
| 140 |
+
for pid in public_patients_id:
|
| 141 |
+
centers.setdefault(pid[:3], []).append(pid)
|
| 142 |
+
|
| 143 |
+
public_test = []
|
| 144 |
+
public_trainval = []
|
| 145 |
+
for center, pids in sorted(centers.items()):
|
| 146 |
+
arr = np.array(pids)
|
| 147 |
+
np.random.default_rng(RANDOM_SEED + hash(center) % 1000).shuffle(arr)
|
| 148 |
+
public_test.append(arr[0]) # 1 test per center -> 3 total
|
| 149 |
+
public_trainval += arr[1:].tolist() # 4 trainval per center -> 12 total
|
| 150 |
+
|
| 151 |
+
public_folds = make_folds_kfold(
|
| 152 |
+
public_trainval,
|
| 153 |
+
n_folds=N_FOLDS,
|
| 154 |
+
rng=np.random.default_rng(RANDOM_SEED + 2),
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
public_split = {
|
| 158 |
+
"metadata": {
|
| 159 |
+
"dataset": "Public_MSSEG",
|
| 160 |
+
"total_patients": n_public,
|
| 161 |
+
"test_patients": len(public_test),
|
| 162 |
+
"trainval_patients": len(public_trainval),
|
| 163 |
+
"target_split": "60/20/20 (train/val/test)",
|
| 164 |
+
"n_folds": N_FOLDS,
|
| 165 |
+
"random_seed": RANDOM_SEED,
|
| 166 |
+
"center_balanced_test": True,
|
| 167 |
+
},
|
| 168 |
+
"test_set": {"patients": public_test, "n_patients": len(public_test)},
|
| 169 |
+
"folds": public_folds,
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 173 |
+
# CONCATENATED
|
| 174 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 175 |
+
concat_test = local_test + public_test
|
| 176 |
+
concat_folds = {}
|
| 177 |
+
for fold_key in local_folds:
|
| 178 |
+
lf = local_folds[fold_key]
|
| 179 |
+
pf = public_folds[fold_key]
|
| 180 |
+
concat_folds[fold_key] = {
|
| 181 |
+
"train_patients": lf["train_patients"] + pf["train_patients"],
|
| 182 |
+
"val_patients": lf["val_patients"] + pf["val_patients"],
|
| 183 |
+
"n_train": lf["n_train"] + pf["n_train"],
|
| 184 |
+
"n_val": lf["n_val"] + pf["n_val"],
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
concat_split = {
|
| 188 |
+
"metadata": {
|
| 189 |
+
"datasets": ["Local_SAI", "Public_MSSEG"],
|
| 190 |
+
"total_patients": n_local + n_public,
|
| 191 |
+
"test_patients": len(concat_test),
|
| 192 |
+
"trainval_patients": len(local_trainval) + len(public_trainval),
|
| 193 |
+
"local_split": "70/10/20",
|
| 194 |
+
"public_split": "60/20/20",
|
| 195 |
+
"n_folds": N_FOLDS,
|
| 196 |
+
"random_seed": RANDOM_SEED,
|
| 197 |
+
},
|
| 198 |
+
"test_set": {"patients": concat_test, "n_patients": len(concat_test)},
|
| 199 |
+
"folds": concat_folds,
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 203 |
+
# Save
|
| 204 |
+
# ───────────────��─────────────────────────────────────────────────────────────
|
| 205 |
+
output_dir = os.path.dirname(os.path.abspath(__file__))
|
| 206 |
+
|
| 207 |
+
for name, data in [
|
| 208 |
+
("local_fold_assignments.json", local_split),
|
| 209 |
+
("public_fold_assignments.json", public_split),
|
| 210 |
+
("concat_fold_assignments.json", concat_split),
|
| 211 |
+
]:
|
| 212 |
+
path = os.path.join(output_dir, name)
|
| 213 |
+
with open(path, "w") as f:
|
| 214 |
+
json.dump(data, f, indent=2)
|
| 215 |
+
print(f"Saved: {path}")
|
| 216 |
+
|
| 217 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 218 |
+
# Sanity check
|
| 219 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 220 |
+
print("\n=== SANITY CHECK ===")
|
| 221 |
+
for label, split_data in [("LOCAL", local_split), ("PUBLIC", public_split), ("CONCAT", concat_split)]:
|
| 222 |
+
test_pts = set(split_data["test_set"]["patients"])
|
| 223 |
+
print(f"\n{label} (test={len(test_pts)})")
|
| 224 |
+
val_sets = []
|
| 225 |
+
for fold_key, fold in split_data["folds"].items():
|
| 226 |
+
train_pts = set(fold["train_patients"])
|
| 227 |
+
val_pts = set(fold["val_patients"])
|
| 228 |
+
val_sets.append(val_pts)
|
| 229 |
+
tv_overlap = len(train_pts & val_pts)
|
| 230 |
+
tst_overlap = len((train_pts | val_pts) & test_pts)
|
| 231 |
+
print(f" {fold_key}: train={len(train_pts):3d}, val={len(val_pts):2d} | "
|
| 232 |
+
f"train/val overlap={tv_overlap} | (train+val)/test overlap={tst_overlap}")
|
| 233 |
+
bad = [f"f{i}&f{j}" for i in range(len(val_sets)) for j in range(i+1, len(val_sets)) if val_sets[i] & val_sets[j]]
|
| 234 |
+
print(f" Val sets unique across folds: {'FAIL: ' + str(bad) if bad else 'OK'}")
|
models/for_WMH_Vent/data_splits/local_fold_assignments.json
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"dataset": "Local_SAI",
|
| 4 |
+
"total_patients": 100,
|
| 5 |
+
"test_patients": 10,
|
| 6 |
+
"trainval_patients": 90,
|
| 7 |
+
"target_split": "70/10/20 (train/val/test)",
|
| 8 |
+
"exact_counts": "train=70, val=10, test=20 per fold",
|
| 9 |
+
"n_folds": 4,
|
| 10 |
+
"random_seed": 42
|
| 11 |
+
},
|
| 12 |
+
"test_set": {
|
| 13 |
+
"patients": [
|
| 14 |
+
"110012",
|
| 15 |
+
"105549",
|
| 16 |
+
"109816",
|
| 17 |
+
"105074",
|
| 18 |
+
"106780",
|
| 19 |
+
"107680",
|
| 20 |
+
"108807",
|
| 21 |
+
"106063",
|
| 22 |
+
"114585",
|
| 23 |
+
"111489"
|
| 24 |
+
],
|
| 25 |
+
"n_patients": 10
|
| 26 |
+
},
|
| 27 |
+
"folds": {
|
| 28 |
+
"fold_0": {
|
| 29 |
+
"train_patients": [
|
| 30 |
+
"109395",
|
| 31 |
+
"115788",
|
| 32 |
+
"113845",
|
| 33 |
+
"114770",
|
| 34 |
+
"102313",
|
| 35 |
+
"104797",
|
| 36 |
+
"111189",
|
| 37 |
+
"105597",
|
| 38 |
+
"111140",
|
| 39 |
+
"106270",
|
| 40 |
+
"114836",
|
| 41 |
+
"108295",
|
| 42 |
+
"104518",
|
| 43 |
+
"110218",
|
| 44 |
+
"110784",
|
| 45 |
+
"101627",
|
| 46 |
+
"104280",
|
| 47 |
+
"107966",
|
| 48 |
+
"101228",
|
| 49 |
+
"104420",
|
| 50 |
+
"109944",
|
| 51 |
+
"114903",
|
| 52 |
+
"112765",
|
| 53 |
+
"106200",
|
| 54 |
+
"106506",
|
| 55 |
+
"106536",
|
| 56 |
+
"112055",
|
| 57 |
+
"104447",
|
| 58 |
+
"106976",
|
| 59 |
+
"105978",
|
| 60 |
+
"110543",
|
| 61 |
+
"114058",
|
| 62 |
+
"113394",
|
| 63 |
+
"107739",
|
| 64 |
+
"112657",
|
| 65 |
+
"111008",
|
| 66 |
+
"105911",
|
| 67 |
+
"111852",
|
| 68 |
+
"105465",
|
| 69 |
+
"114128",
|
| 70 |
+
"110280",
|
| 71 |
+
"112414",
|
| 72 |
+
"105302",
|
| 73 |
+
"107455",
|
| 74 |
+
"110327",
|
| 75 |
+
"114990",
|
| 76 |
+
"112730",
|
| 77 |
+
"104453",
|
| 78 |
+
"111691",
|
| 79 |
+
"114454",
|
| 80 |
+
"104474",
|
| 81 |
+
"104252",
|
| 82 |
+
"109654",
|
| 83 |
+
"104937",
|
| 84 |
+
"104871",
|
| 85 |
+
"107508",
|
| 86 |
+
"114525",
|
| 87 |
+
"115588",
|
| 88 |
+
"110540",
|
| 89 |
+
"109267",
|
| 90 |
+
"107539",
|
| 91 |
+
"108344",
|
| 92 |
+
"112659",
|
| 93 |
+
"112776",
|
| 94 |
+
"113046",
|
| 95 |
+
"107233",
|
| 96 |
+
"102035",
|
| 97 |
+
"106905",
|
| 98 |
+
"107997",
|
| 99 |
+
"112378",
|
| 100 |
+
"104520",
|
| 101 |
+
"106639",
|
| 102 |
+
"104670",
|
| 103 |
+
"104899",
|
| 104 |
+
"115628",
|
| 105 |
+
"108444",
|
| 106 |
+
"109923",
|
| 107 |
+
"110157",
|
| 108 |
+
"114304",
|
| 109 |
+
"114266"
|
| 110 |
+
],
|
| 111 |
+
"val_patients": [
|
| 112 |
+
"108726",
|
| 113 |
+
"105917",
|
| 114 |
+
"105755",
|
| 115 |
+
"109141",
|
| 116 |
+
"110497",
|
| 117 |
+
"112997",
|
| 118 |
+
"104810",
|
| 119 |
+
"108975",
|
| 120 |
+
"107130",
|
| 121 |
+
"107630"
|
| 122 |
+
],
|
| 123 |
+
"n_train": 80,
|
| 124 |
+
"n_val": 10
|
| 125 |
+
},
|
| 126 |
+
"fold_1": {
|
| 127 |
+
"train_patients": [
|
| 128 |
+
"108726",
|
| 129 |
+
"105917",
|
| 130 |
+
"105755",
|
| 131 |
+
"109141",
|
| 132 |
+
"110497",
|
| 133 |
+
"112997",
|
| 134 |
+
"104810",
|
| 135 |
+
"108975",
|
| 136 |
+
"107130",
|
| 137 |
+
"107630",
|
| 138 |
+
"114836",
|
| 139 |
+
"108295",
|
| 140 |
+
"104518",
|
| 141 |
+
"110218",
|
| 142 |
+
"110784",
|
| 143 |
+
"101627",
|
| 144 |
+
"104280",
|
| 145 |
+
"107966",
|
| 146 |
+
"101228",
|
| 147 |
+
"104420",
|
| 148 |
+
"109944",
|
| 149 |
+
"114903",
|
| 150 |
+
"112765",
|
| 151 |
+
"106200",
|
| 152 |
+
"106506",
|
| 153 |
+
"106536",
|
| 154 |
+
"112055",
|
| 155 |
+
"104447",
|
| 156 |
+
"106976",
|
| 157 |
+
"105978",
|
| 158 |
+
"110543",
|
| 159 |
+
"114058",
|
| 160 |
+
"113394",
|
| 161 |
+
"107739",
|
| 162 |
+
"112657",
|
| 163 |
+
"111008",
|
| 164 |
+
"105911",
|
| 165 |
+
"111852",
|
| 166 |
+
"105465",
|
| 167 |
+
"114128",
|
| 168 |
+
"110280",
|
| 169 |
+
"112414",
|
| 170 |
+
"105302",
|
| 171 |
+
"107455",
|
| 172 |
+
"110327",
|
| 173 |
+
"114990",
|
| 174 |
+
"112730",
|
| 175 |
+
"104453",
|
| 176 |
+
"111691",
|
| 177 |
+
"114454",
|
| 178 |
+
"104474",
|
| 179 |
+
"104252",
|
| 180 |
+
"109654",
|
| 181 |
+
"104937",
|
| 182 |
+
"104871",
|
| 183 |
+
"107508",
|
| 184 |
+
"114525",
|
| 185 |
+
"115588",
|
| 186 |
+
"110540",
|
| 187 |
+
"109267",
|
| 188 |
+
"107539",
|
| 189 |
+
"108344",
|
| 190 |
+
"112659",
|
| 191 |
+
"112776",
|
| 192 |
+
"113046",
|
| 193 |
+
"107233",
|
| 194 |
+
"102035",
|
| 195 |
+
"106905",
|
| 196 |
+
"107997",
|
| 197 |
+
"112378",
|
| 198 |
+
"104520",
|
| 199 |
+
"106639",
|
| 200 |
+
"104670",
|
| 201 |
+
"104899",
|
| 202 |
+
"115628",
|
| 203 |
+
"108444",
|
| 204 |
+
"109923",
|
| 205 |
+
"110157",
|
| 206 |
+
"114304",
|
| 207 |
+
"114266"
|
| 208 |
+
],
|
| 209 |
+
"val_patients": [
|
| 210 |
+
"109395",
|
| 211 |
+
"115788",
|
| 212 |
+
"113845",
|
| 213 |
+
"114770",
|
| 214 |
+
"102313",
|
| 215 |
+
"104797",
|
| 216 |
+
"111189",
|
| 217 |
+
"105597",
|
| 218 |
+
"111140",
|
| 219 |
+
"106270"
|
| 220 |
+
],
|
| 221 |
+
"n_train": 80,
|
| 222 |
+
"n_val": 10
|
| 223 |
+
},
|
| 224 |
+
"fold_2": {
|
| 225 |
+
"train_patients": [
|
| 226 |
+
"108726",
|
| 227 |
+
"105917",
|
| 228 |
+
"105755",
|
| 229 |
+
"109141",
|
| 230 |
+
"110497",
|
| 231 |
+
"112997",
|
| 232 |
+
"104810",
|
| 233 |
+
"108975",
|
| 234 |
+
"107130",
|
| 235 |
+
"107630",
|
| 236 |
+
"109395",
|
| 237 |
+
"115788",
|
| 238 |
+
"113845",
|
| 239 |
+
"114770",
|
| 240 |
+
"102313",
|
| 241 |
+
"104797",
|
| 242 |
+
"111189",
|
| 243 |
+
"105597",
|
| 244 |
+
"111140",
|
| 245 |
+
"106270",
|
| 246 |
+
"109944",
|
| 247 |
+
"114903",
|
| 248 |
+
"112765",
|
| 249 |
+
"106200",
|
| 250 |
+
"106506",
|
| 251 |
+
"106536",
|
| 252 |
+
"112055",
|
| 253 |
+
"104447",
|
| 254 |
+
"106976",
|
| 255 |
+
"105978",
|
| 256 |
+
"110543",
|
| 257 |
+
"114058",
|
| 258 |
+
"113394",
|
| 259 |
+
"107739",
|
| 260 |
+
"112657",
|
| 261 |
+
"111008",
|
| 262 |
+
"105911",
|
| 263 |
+
"111852",
|
| 264 |
+
"105465",
|
| 265 |
+
"114128",
|
| 266 |
+
"110280",
|
| 267 |
+
"112414",
|
| 268 |
+
"105302",
|
| 269 |
+
"107455",
|
| 270 |
+
"110327",
|
| 271 |
+
"114990",
|
| 272 |
+
"112730",
|
| 273 |
+
"104453",
|
| 274 |
+
"111691",
|
| 275 |
+
"114454",
|
| 276 |
+
"104474",
|
| 277 |
+
"104252",
|
| 278 |
+
"109654",
|
| 279 |
+
"104937",
|
| 280 |
+
"104871",
|
| 281 |
+
"107508",
|
| 282 |
+
"114525",
|
| 283 |
+
"115588",
|
| 284 |
+
"110540",
|
| 285 |
+
"109267",
|
| 286 |
+
"107539",
|
| 287 |
+
"108344",
|
| 288 |
+
"112659",
|
| 289 |
+
"112776",
|
| 290 |
+
"113046",
|
| 291 |
+
"107233",
|
| 292 |
+
"102035",
|
| 293 |
+
"106905",
|
| 294 |
+
"107997",
|
| 295 |
+
"112378",
|
| 296 |
+
"104520",
|
| 297 |
+
"106639",
|
| 298 |
+
"104670",
|
| 299 |
+
"104899",
|
| 300 |
+
"115628",
|
| 301 |
+
"108444",
|
| 302 |
+
"109923",
|
| 303 |
+
"110157",
|
| 304 |
+
"114304",
|
| 305 |
+
"114266"
|
| 306 |
+
],
|
| 307 |
+
"val_patients": [
|
| 308 |
+
"114836",
|
| 309 |
+
"108295",
|
| 310 |
+
"104518",
|
| 311 |
+
"110218",
|
| 312 |
+
"110784",
|
| 313 |
+
"101627",
|
| 314 |
+
"104280",
|
| 315 |
+
"107966",
|
| 316 |
+
"101228",
|
| 317 |
+
"104420"
|
| 318 |
+
],
|
| 319 |
+
"n_train": 80,
|
| 320 |
+
"n_val": 10
|
| 321 |
+
},
|
| 322 |
+
"fold_3": {
|
| 323 |
+
"train_patients": [
|
| 324 |
+
"108726",
|
| 325 |
+
"105917",
|
| 326 |
+
"105755",
|
| 327 |
+
"109141",
|
| 328 |
+
"110497",
|
| 329 |
+
"112997",
|
| 330 |
+
"104810",
|
| 331 |
+
"108975",
|
| 332 |
+
"107130",
|
| 333 |
+
"107630",
|
| 334 |
+
"109395",
|
| 335 |
+
"115788",
|
| 336 |
+
"113845",
|
| 337 |
+
"114770",
|
| 338 |
+
"102313",
|
| 339 |
+
"104797",
|
| 340 |
+
"111189",
|
| 341 |
+
"105597",
|
| 342 |
+
"111140",
|
| 343 |
+
"106270",
|
| 344 |
+
"114836",
|
| 345 |
+
"108295",
|
| 346 |
+
"104518",
|
| 347 |
+
"110218",
|
| 348 |
+
"110784",
|
| 349 |
+
"101627",
|
| 350 |
+
"104280",
|
| 351 |
+
"107966",
|
| 352 |
+
"101228",
|
| 353 |
+
"104420",
|
| 354 |
+
"110543",
|
| 355 |
+
"114058",
|
| 356 |
+
"113394",
|
| 357 |
+
"107739",
|
| 358 |
+
"112657",
|
| 359 |
+
"111008",
|
| 360 |
+
"105911",
|
| 361 |
+
"111852",
|
| 362 |
+
"105465",
|
| 363 |
+
"114128",
|
| 364 |
+
"110280",
|
| 365 |
+
"112414",
|
| 366 |
+
"105302",
|
| 367 |
+
"107455",
|
| 368 |
+
"110327",
|
| 369 |
+
"114990",
|
| 370 |
+
"112730",
|
| 371 |
+
"104453",
|
| 372 |
+
"111691",
|
| 373 |
+
"114454",
|
| 374 |
+
"104474",
|
| 375 |
+
"104252",
|
| 376 |
+
"109654",
|
| 377 |
+
"104937",
|
| 378 |
+
"104871",
|
| 379 |
+
"107508",
|
| 380 |
+
"114525",
|
| 381 |
+
"115588",
|
| 382 |
+
"110540",
|
| 383 |
+
"109267",
|
| 384 |
+
"107539",
|
| 385 |
+
"108344",
|
| 386 |
+
"112659",
|
| 387 |
+
"112776",
|
| 388 |
+
"113046",
|
| 389 |
+
"107233",
|
| 390 |
+
"102035",
|
| 391 |
+
"106905",
|
| 392 |
+
"107997",
|
| 393 |
+
"112378",
|
| 394 |
+
"104520",
|
| 395 |
+
"106639",
|
| 396 |
+
"104670",
|
| 397 |
+
"104899",
|
| 398 |
+
"115628",
|
| 399 |
+
"108444",
|
| 400 |
+
"109923",
|
| 401 |
+
"110157",
|
| 402 |
+
"114304",
|
| 403 |
+
"114266"
|
| 404 |
+
],
|
| 405 |
+
"val_patients": [
|
| 406 |
+
"109944",
|
| 407 |
+
"114903",
|
| 408 |
+
"112765",
|
| 409 |
+
"106200",
|
| 410 |
+
"106506",
|
| 411 |
+
"106536",
|
| 412 |
+
"112055",
|
| 413 |
+
"104447",
|
| 414 |
+
"106976",
|
| 415 |
+
"105978"
|
| 416 |
+
],
|
| 417 |
+
"n_train": 80,
|
| 418 |
+
"n_val": 10
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
}
|
models/for_WMH_Vent/data_splits/public_fold_assignments.json
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"dataset": "Public_MSSEG",
|
| 4 |
+
"total_patients": 15,
|
| 5 |
+
"test_patients": 3,
|
| 6 |
+
"trainval_patients": 12,
|
| 7 |
+
"target_split": "60/20/20 (train/val/test)",
|
| 8 |
+
"n_folds": 4,
|
| 9 |
+
"random_seed": 42,
|
| 10 |
+
"center_balanced_test": true
|
| 11 |
+
},
|
| 12 |
+
"test_set": {
|
| 13 |
+
"patients": [
|
| 14 |
+
"c01p04",
|
| 15 |
+
"c07p05",
|
| 16 |
+
"c08p04"
|
| 17 |
+
],
|
| 18 |
+
"n_patients": 3
|
| 19 |
+
},
|
| 20 |
+
"folds": {
|
| 21 |
+
"fold_0": {
|
| 22 |
+
"train_patients": [
|
| 23 |
+
"c08p03",
|
| 24 |
+
"c01p01",
|
| 25 |
+
"c08p02",
|
| 26 |
+
"c07p03",
|
| 27 |
+
"c07p04",
|
| 28 |
+
"c01p02",
|
| 29 |
+
"c07p01",
|
| 30 |
+
"c08p05",
|
| 31 |
+
"c07p02"
|
| 32 |
+
],
|
| 33 |
+
"val_patients": [
|
| 34 |
+
"c01p05",
|
| 35 |
+
"c08p01",
|
| 36 |
+
"c01p03"
|
| 37 |
+
],
|
| 38 |
+
"n_train": 9,
|
| 39 |
+
"n_val": 3
|
| 40 |
+
},
|
| 41 |
+
"fold_1": {
|
| 42 |
+
"train_patients": [
|
| 43 |
+
"c01p05",
|
| 44 |
+
"c08p01",
|
| 45 |
+
"c01p03",
|
| 46 |
+
"c07p03",
|
| 47 |
+
"c07p04",
|
| 48 |
+
"c01p02",
|
| 49 |
+
"c07p01",
|
| 50 |
+
"c08p05",
|
| 51 |
+
"c07p02"
|
| 52 |
+
],
|
| 53 |
+
"val_patients": [
|
| 54 |
+
"c08p03",
|
| 55 |
+
"c01p01",
|
| 56 |
+
"c08p02"
|
| 57 |
+
],
|
| 58 |
+
"n_train": 9,
|
| 59 |
+
"n_val": 3
|
| 60 |
+
},
|
| 61 |
+
"fold_2": {
|
| 62 |
+
"train_patients": [
|
| 63 |
+
"c01p05",
|
| 64 |
+
"c08p01",
|
| 65 |
+
"c01p03",
|
| 66 |
+
"c08p03",
|
| 67 |
+
"c01p01",
|
| 68 |
+
"c08p02",
|
| 69 |
+
"c07p01",
|
| 70 |
+
"c08p05",
|
| 71 |
+
"c07p02"
|
| 72 |
+
],
|
| 73 |
+
"val_patients": [
|
| 74 |
+
"c07p03",
|
| 75 |
+
"c07p04",
|
| 76 |
+
"c01p02"
|
| 77 |
+
],
|
| 78 |
+
"n_train": 9,
|
| 79 |
+
"n_val": 3
|
| 80 |
+
},
|
| 81 |
+
"fold_3": {
|
| 82 |
+
"train_patients": [
|
| 83 |
+
"c01p05",
|
| 84 |
+
"c08p01",
|
| 85 |
+
"c01p03",
|
| 86 |
+
"c08p03",
|
| 87 |
+
"c01p01",
|
| 88 |
+
"c08p02",
|
| 89 |
+
"c07p03",
|
| 90 |
+
"c07p04",
|
| 91 |
+
"c01p02"
|
| 92 |
+
],
|
| 93 |
+
"val_patients": [
|
| 94 |
+
"c07p01",
|
| 95 |
+
"c08p05",
|
| 96 |
+
"c07p02"
|
| 97 |
+
],
|
| 98 |
+
"n_train": 9,
|
| 99 |
+
"n_val": 3
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
}
|
models/for_WMH_Vent/download_models.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Visit our Hugging Face link for downloading the trained models.
|
models/for_WMH_Vent/folds_results_zscore2_all/per_class_summary.csv
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Variant,Variant_Name,Class,Class_Name,DICE_mean,DICE_std,DICE_min,DICE_max,PRECISION_mean,PRECISION_std,PRECISION_min,PRECISION_max,RECALL_mean,RECALL_std,RECALL_min,RECALL_max,IOU_mean,IOU_std,IOU_min,IOU_max,SPECIFICITY_mean,SPECIFICITY_std,SPECIFICITY_min,SPECIFICITY_max,HD95_mean,HD95_std,HD95_min,HD95_max,LESION_SENSITIVITY_mean,LESION_SENSITIVITY_std,LESION_PRECISION_mean,LESION_PRECISION_std,LESION_F1_mean,LESION_F1_std,LESION_N_GT_LESIONS_total,LESION_N_PRED_LESIONS_total,LESION_TP_LESIONS_total,LESION_FN_LESIONS_total,LESION_FP_LESIONS_total
|
| 2 |
+
1,unet,1,Ventricles,0.9296308495604303,0.003051861083997252,0.9245313971595007,0.9325869622190041,0.937810327296536,0.004534371323946414,0.9299792648109558,0.9408155762340407,0.9221807114485115,0.002258280011483868,0.9198278289806433,0.9258708828514376,0.86883963293893,0.005257697310767231,0.8600597389216186,0.8739402369187041,0.9992060262932462,5.3781685628696937e-05,0.9991132143255909,0.9992439887040112,1.0,0.0,1.0,1.0,,,,,,,0.0,0.0,0.0,0.0,0.0
|
| 3 |
+
1,unet,2,Abnormal_WMH,0.8471261192911104,0.006988603634009174,0.8380055641483046,0.8562749203494235,0.8861666894324636,0.004785959852918547,0.8819829027399184,0.8938984076631564,0.8156915305742668,0.008049307835491817,0.8038219006844454,0.8241991234390766,0.7363711260717487,0.01045759631012397,0.7227128240374371,0.7500107261976053,0.9992840254178714,1.4410851399106235e-05,0.9992605553058871,0.9992976281722222,4.579276208116416,0.9906935564361317,3.105706418326917,5.669535448758443,,,,,,,,,,,
|
| 4 |
+
2,attnunet,1,Ventricles,0.9104890513851166,0.024899999222747722,0.8675609526332009,0.9278078258646835,0.9203443411150141,0.02293220806285698,0.8806633355974736,0.9350793260117669,0.9019219497921485,0.026912289770452267,0.8562340091217527,0.923007569878129,0.83718265247343,0.040705456147877725,0.7670303812669205,0.8657288299651152,0.9989795287466985,0.000273307129456502,0.9985083429568933,0.9991706542188458,1.2282992876459566,0.3954259655345788,1.0,1.913197150583827,,,,,,,0.0,0.0,0.0,0.0,0.0
|
| 5 |
+
2,attnunet,2,Abnormal_WMH,0.826975751920205,0.015579775036519973,0.8023896495263613,0.8453085695774089,0.8886304925692461,0.009863625617703263,0.8773743681724526,0.9034203955732797,0.779984526149581,0.01966711140506486,0.7519818345459007,0.8069007799264991,0.7066411846722295,0.022442058844092443,0.6715032946218304,0.7335430311874717,0.9993673813767945,5.564253319795054e-05,0.9993140913333589,0.9994591017880485,5.868210623237481,1.1565233310098124,4.299965705388233,7.125643309860791,,,,,,,,,,,
|
| 6 |
+
3,dlv3unet,1,Ventricles,0.9005661992435416,0.0020867923289419102,0.8974816472833985,0.9032637548534808,0.8997698242284116,0.002619190013214462,0.8961697010761339,0.9031854303942591,0.9018365317555639,0.00304923631516394,0.896838192141597,0.9048159553177055,0.8198187029362412,0.0034773428576675017,0.8147050743404023,0.8243681794688564,0.9987641261362938,3.990839508953154e-05,0.9987224004496448,0.9988270370474228,1.0,0.0,1.0,1.0,,,,,,,0.0,0.0,0.0,0.0,0.0
|
| 7 |
+
3,dlv3unet,2,Abnormal_WMH,0.7763168733853871,0.003073255677872925,0.772352120446274,0.7808450853012623,0.7932948495860329,0.01339398897949029,0.775105014208068,0.8127526621991921,0.7653741420470819,0.01188742959950416,0.7489376058870139,0.7803045906209611,0.6370210758311682,0.003999380279290716,0.6319568324618485,0.6429538443669356,0.9985668433084038,0.00013274507182798207,0.9983882554846853,0.9987628711423673,4.7126929962683395,0.5556289444958085,4.095494923513843,5.423612659365294,,,,,,,,,,,
|
| 8 |
+
4,transunet,1,Ventricles,0.9246872887842248,0.004597522753464204,0.917144392619005,0.9284374594079503,0.9320059959760637,0.011631626135529186,0.9158405109177434,0.9481783808085702,0.9184641125298365,0.004922383784531681,0.9100900076594562,0.9224872999070102,0.8603159951545595,0.007862485177144832,0.8474580455251116,0.8667238350341302,0.9991215386213639,0.00017857472387549319,0.9988581545201223,0.999358864513995,1.0,0.0,1.0,1.0,,,,,,,0.0,0.0,0.0,0.0,0.0
|
| 9 |
+
4,transunet,2,Abnormal_WMH,0.8322919090444327,0.010816310171427137,0.81389122085058,0.8417282856836877,0.9035192038694635,0.003183241810633453,0.8989558813413554,0.9074719810187692,0.7761566255927599,0.015082099936625985,0.7515534359141638,0.7926254513986619,0.7142798166712054,0.01573238533985691,0.6875437726337009,0.7281151773663771,0.9994577459658872,1.5563988190790074e-05,0.9994339623025651,0.9994755539827085,5.929181221900818,1.9288286807668098,4.026591793193768,8.744998558832915,,,,,,,,,,,
|
models/for_WMH_Vent/folds_results_zscore2_all/test_metrics_all_variants_folds.csv
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Variant,Variant_Name,Fold,Test_Samples,DICE_class_1,DICE_class_2,DICE_mean,PRECISION_class_1,PRECISION_class_2,PRECISION_mean,RECALL_class_1,RECALL_class_2,RECALL_mean,IOU_class_1,IOU_class_2,IOU_mean,SPECIFICITY_class_1,SPECIFICITY_class_2,SPECIFICITY_mean,HD95_class_1,HD95_class_2,HD95_mean,LESION_LESION_SENSITIVITY_class_0,LESION_LESION_PRECISION_class_0,LESION_LESION_F1_class_0,LESION_N_GT_LESIONS_class_0,LESION_N_PRED_LESIONS_class_0,LESION_TP_LESIONS_class_0,LESION_FN_LESIONS_class_0,LESION_FP_LESIONS_class_0,LESION_LESION_SENSITIVITY_class_1,LESION_LESION_PRECISION_class_1,LESION_LESION_F1_class_1,LESION_N_GT_LESIONS_class_1,LESION_N_PRED_LESIONS_class_1,LESION_TP_LESIONS_class_1,LESION_FN_LESIONS_class_1,LESION_FP_LESIONS_class_1,LESION_LESION_SENSITIVITY_mean,LESION_LESION_PRECISION_mean,LESION_LESION_F1_mean,LESION_N_GT_LESIONS_total,LESION_N_PRED_LESIONS_total,LESION_TP_LESIONS_total,LESION_FN_LESIONS_total,LESION_FP_LESIONS_total
|
| 2 |
+
1,unet,0,70,0.924531397,0.843338613,0.883935005,0.929979265,0.882387206,0.906183236,0.919827829,0.812886653,0.866357241,0.860059739,0.730685865,0.795372802,0.999113214,0.999260555,0.999186885,1,5.669535449,3.334767724,,,,,,,,,,,,,,,,,0.810285987,0.717464393,0.753906356,275,309,226,49,84
|
| 3 |
+
1,unet,1,70,0.931030738,0.850885379,0.890958059,0.940815576,0.886398241,0.913606909,0.921931425,0.821858445,0.871894935,0.87127183,0.742075089,0.806673459,0.999243989,0.999293762,0.999268875,1,3.105706418,2.052853209,,,,,,,,,,,,,,,,,0.831870007,0.757975006,0.788804562,275,308,230,45,76
|
| 4 |
+
1,unet,2,70,0.932586962,0.85627492,0.894430941,0.939876251,0.893898408,0.916887329,0.925870883,0.824199123,0.875035003,0.873940237,0.750010726,0.811975482,0.999235735,0.999297628,0.999266682,1,4.274767062,2.637383531,,,,,,,,,,,,,,,,,0.8190735,0.761199505,0.785236019,275,299,227,48,69
|
| 5 |
+
1,unet,3,70,0.930374301,0.838005564,0.884189933,0.940570217,0.881982903,0.91127656,0.921092709,0.803821901,0.862457305,0.870086726,0.722712824,0.796399775,0.999231167,0.999284157,0.999257662,1,5.267095904,3.133547952,,,,,,,,,,,,,,,,,0.803511136,0.755088192,0.768750434,275,312,221,54,92
|
| 6 |
+
2,attnunet,0,70,0.927807826,0.84530857,0.886558198,0.933405023,0.891192522,0.912298773,0.92300757,0.80690078,0.864954175,0.86572883,0.733543031,0.799635931,0.9991422,0.99933414,0.99923817,1,6.817768379,3.90888419,,,,,,,,,,,,,,,,,0.805918858,0.737682024,0.763465574,275,311,224,51,81
|
| 7 |
+
2,attnunet,1,70,0.921130442,0.827778912,0.874454677,0.935079326,0.882534684,0.908807005,0.908634865,0.784884088,0.846759477,0.854381693,0.707421007,0.78090135,0.999170654,0.999314091,0.999242373,1,4.299965705,2.649982853,,,,,,,,,,,,,,,,,0.797643669,0.741290568,0.757996731,275,306,221,54,83
|
| 8 |
+
2,attnunet,2,70,0.925456985,0.832425877,0.878941431,0.93222968,0.903420396,0.917825038,0.919811356,0.776171402,0.847991379,0.861589705,0.714097406,0.787843556,0.999096918,0.999459102,0.99927801,1,5.229465099,3.114732549,,,,,,,,,,,,,,,,,0.800801564,0.781510442,0.783598099,275,291,222,53,66
|
| 9 |
+
2,attnunet,3,70,0.867560953,0.80238965,0.834975301,0.880663336,0.877374368,0.879018852,0.856234009,0.751981835,0.804107922,0.767030381,0.671503295,0.719266838,0.998508343,0.999362192,0.998935268,1.913197151,7.12564331,4.51942023,,,,,,,,,,,,,,,,,0.800023323,0.639032382,0.699206596,275,339,222,53,112
|
| 10 |
+
3,dlv3unet,0,70,0.900234027,0.780845085,0.840539556,0.896169701,0.794573359,0.84537153,0.904815955,0.772086706,0.838451331,0.819275677,0.642953844,0.731114761,0.9987224,0.998555632,0.998639016,1,5.423612659,3.21180633,,,,,,,,,,,,,,,,,0.753118133,0.708678143,0.719359836,275,287,209,66,85
|
| 11 |
+
3,dlv3unet,1,70,0.903263755,0.776870874,0.840067314,0.90318543,0.812752662,0.857969046,0.90365004,0.748937606,0.826293823,0.824368179,0.63776127,0.731064725,0.998827037,0.998762871,0.998794954,1,4.251005732,2.625502866,,,,,,,,,,,,,,,,,0.730627383,0.773161114,0.746853388,275,257,200,75,65
|
| 12 |
+
3,dlv3unet,2,70,0.897481647,0.77235212,0.834916884,0.898679638,0.790748363,0.844714001,0.896838192,0.760167666,0.828502929,0.814705074,0.631956832,0.723330953,0.998738575,0.998560615,0.998649595,1,5.08065867,3.040329335,,,,,,,,,,,,,,,,,0.710273969,0.713006921,0.702347611,275,273,196,79,86
|
| 13 |
+
3,dlv3unet,3,70,0.901285368,0.775199414,0.838242391,0.901044527,0.775105014,0.838074771,0.902041939,0.780304591,0.841173265,0.820925881,0.635412356,0.728169118,0.998768492,0.998388255,0.998578374,1,4.095494924,2.547747462,,,,,,,,,,,,,,,,,0.686803058,0.707394828,0.686217762,275,264,189,86,84
|
| 14 |
+
4,transunet,0,70,0.928372145,0.841728286,0.885050215,0.948178381,0.902428523,0.925303452,0.910090008,0.792625451,0.85135773,0.866710076,0.728115177,0.797412627,0.999358865,0.999433962,0.999396413,1,8.744998559,4.872499279,,,,,,,,,,,,,,,,,0.791318113,0.78846014,0.778279842,275,299,224,51,71
|
| 15 |
+
4,transunet,1,70,0.928437459,0.837057166,0.882747312,0.935028265,0.907471981,0.921250123,0.9224873,0.779866141,0.851176721,0.866723835,0.721270863,0.793997349,0.999162149,0.99945489,0.999308519,1,6.680013727,3.840006863,,,,,,,,,,,,,,,,,0.769424443,0.762231589,0.753840342,275,297,214,61,74
|
| 16 |
+
4,transunet,2,70,0.924795158,0.836490964,0.880643061,0.928976828,0.90522043,0.917098629,0.921395395,0.780581474,0.850988434,0.860372024,0.720189454,0.790280739,0.999106986,0.999466578,0.999286782,1,4.026591793,2.513295897,,,,,,,,,,,,,,,,,0.768415042,0.766517149,0.760036199,275,282,215,60,64
|
| 17 |
+
4,transunet,3,70,0.917144393,0.813891221,0.865517807,0.915840511,0.898955881,0.907398196,0.919883748,0.751553436,0.835718592,0.847458046,0.687543773,0.767500909,0.998858155,0.999475554,0.999166854,1,4.265120809,2.632560404,,,,,,,,,,,,,,,,,0.810819711,0.695502554,0.740358673,275,330,225,50,98
|
| 18 |
+
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
|
| 19 |
+
,unet - mean,,,0.9296,0.8471,0.8884,0.9378,0.8862,0.912,0.9222,0.8157,0.8689,0.8688,0.7364,0.8026,0.9992,0.9993,0.9992,1,4.6,2.8,,,,,,,,,,,,,,,,,0.8162,0.7479,0.7742,275,307,226,49,80.25
|
| 20 |
+
,attn - mean,,,0.9105,0.827,0.8687,0.9203,0.8886,0.9045,0.9019,0.78,0.841,0.8372,0.7066,0.7719,0.999,0.9994,0.9992,1.2,5.9,3.5,,,,,,,,,,,,,,,,,0.8011,0.7249,0.7511,275,311.75,222.25,52.75,85.5
|
| 21 |
+
,dlv3 - mean,,,0.9006,0.7763,0.8384,0.8998,0.7933,0.8465,0.9018,0.7654,0.8336,0.8198,0.637,0.7284,0.9988,0.9986,0.9987,1,4.7,2.9,,,,,,,,,,,,,,,,,0.7202,0.7256,0.7137,275,270.25,198.5,76.5,80
|
| 22 |
+
,trans - mean,,,0.9247,0.8323,0.8785,0.932,0.9035,0.9178,0.9185,0.7762,0.8473,0.8603,0.7143,0.7873,0.9991,0.9995,0.9993,1,5.9,3.5,,,,,,,,,,,,,,,,,0.785,0.7532,0.7581,275,302,219.5,55.5,76.75
|
| 23 |
+
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
|
| 24 |
+
,unet - std,,,0.0031,0.007,0.0045,0.0045,0.0048,0.0039,0.0023,0.008,0.0049,0.0053,0.0105,0.007,0.0001,0,0,0,1,0.5,,,,,,,,,,,,,,,,,0.0106,0.0177,0.0139,0,4.8477,3.2404,3.2404,8.6132
|
| 25 |
+
,attn - std,,,0.0249,0.0156,0.02,0.0229,0.0099,0.0151,0.0269,0.0197,0.0225,0.0407,0.0224,0.0311,0.0003,0.0001,0.0001,0.4,1.2,0.7,,,,,,,,,,,,,,,,,0.003,0.0525,0.0314,0,17.3692,1.0897,1.0897,16.6508
|
| 26 |
+
,dlv3 - std,,,0.0021,0.0031,0.0022,0.0026,0.0134,0.0072,0.003,0.0119,0.0063,0.0035,0.004,0.0032,0,0.0001,0.0001,0,0.6,0.3,,,,,,,,,,,,,,,,,0.0245,0.0276,0.0224,0,11.211,7.2284,7.2284,8.6891
|
| 27 |
+
,trans - std,,,0.0046,0.0108,0.0076,0.0116,0.0032,0.0066,0.0049,0.0151,0.0067,0.0079,0.0157,0.0117,0.0002,0,0.0001,0,1.9,1,,,,,,,,,,,,,,,,,0.0175,0.0348,0.0136,0,17.4499,5.0249,5.0249,12.794
|
models/for_WMH_Vent/folds_results_zscore2_all/training_info_all_variants_folds.csv
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Variant,Variant_Name,Fold,Best_Epoch,Composite_Score,Total_Epochs,First_Valid_Epoch,Total_Valid_Epochs,Best_Epoch_Val_Loss,Best_Epoch_Dice_Ventricles,Best_Epoch_Dice_Abnormal_WMH,Best_Epoch_Dice_Mean,Best_Abnormal_Epoch,Best_Abnormal_Dice,Best_Ventricles_Epoch,Best_Ventricles_Dice
|
| 2 |
+
1,unet,0,49,0.837773437480731,60,1,60,0.24741753935813904,0.9308493801953847,0.8054339622632157,0.9115566193830764,43,0.8058087693216404,49,0.9308493801953847
|
| 3 |
+
1,unet,1,45,0.8509202240606865,60,1,60,0.3080134391784668,0.9268369262837436,0.8394508168223501,0.9215010850395001,28,0.8441074580031014,38,0.9274915960857308
|
| 4 |
+
1,unet,2,36,0.8128944361644407,60,1,60,0.27736401557922363,0.9342240045327603,0.7672727272708917,0.9000714500704411,32,0.7696575927137208,34,0.9378331718769447
|
| 5 |
+
1,unet,3,41,0.8148548201069025,60,1,60,0.3056482672691345,0.9412208603997376,0.7717556478564912,0.9037425888471717,41,0.7717556478564912,44,0.9415513142951589
|
| 6 |
+
2,attnunet,0,38,0.8465806985395226,60,1,60,0.2354777455329895,0.9361820594989245,0.8154564254052402,0.9167136699540254,38,0.8154564254052402,49,0.9369088654755128
|
| 7 |
+
2,attnunet,1,42,0.8468065449382642,60,1,60,0.3282952904701233,0.9189075870475,0.8399396631183776,0.9189869898404228,42,0.8399396631183776,42,0.9189075870475
|
| 8 |
+
2,attnunet,2,35,0.8082210232243792,60,1,60,0.2833690643310547,0.9301114433264854,0.7625408277658984,0.8971071685730373,35,0.7625408277658984,38,0.932546742487403
|
| 9 |
+
2,attnunet,3,35,0.7675559444491301,60,1,60,0.3719373941421509,0.8997412800024551,0.7247121664376812,0.8740189336882455,35,0.7247121664376812,51,0.9082138618936411
|
| 10 |
+
3,dlv3unet,0,41,0.7945477803722963,60,1,60,0.3116353750228882,0.8988588122221663,0.7600894570132255,0.8856122734890453,41,0.7600894570132255,54,0.9004052827384709
|
| 11 |
+
3,dlv3unet,1,42,0.8150221762616997,60,1,60,0.3728603720664978,0.9037163637265287,0.8019888405839849,0.9011564065443626,42,0.8019888405839849,40,0.9049275398385249
|
| 12 |
+
3,dlv3unet,2,28,0.7727672322932403,60,1,60,0.34316256642341614,0.9029322657428571,0.7270063486878747,0.8760687237787342,34,0.7281193622294404,40,0.9059795923856953
|
| 13 |
+
3,dlv3unet,3,28,0.768303148626621,60,1,60,0.3877088725566864,0.9124585568837787,0.7222274480285933,0.8774333183819368,28,0.7222274480285933,37,0.916682545438336
|
| 14 |
+
4,transunet,0,39,0.8410510311241638,60,1,60,0.24608999490737915,0.9346885813142664,0.808755760367703,0.9139696605000669,35,0.8090334741168289,45,0.9357508099451346
|
| 15 |
+
4,transunet,1,48,0.8483119522767122,60,1,60,0.3149019777774811,0.9253444084272909,0.8369980458771218,0.9201816141429934,29,0.8392693984451316,50,0.9254004460337096
|
| 16 |
+
4,transunet,2,35,0.8109694312756469,60,1,60,0.27634185552597046,0.9331867846270409,0.7644126357335528,0.89876925662194,35,0.7644126357335528,55,0.9336813537844856
|
| 17 |
+
4,transunet,3,28,0.7773769906034641,60,1,60,0.34197694063186646,0.9332184349847285,0.7193485902853871,0.883513943096132,28,0.7193485902853871,50,0.9413152166406702
|
models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_test.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Variant,Variant_Name,N_Folds,DICE_Mean,DICE_Std,DICE_Class1_Mean,DICE_Class1_Std,DICE_Class2_Mean,DICE_Class2_Std,PRECISION_Mean,PRECISION_Std,PRECISION_Class1_Mean,PRECISION_Class1_Std,PRECISION_Class2_Mean,PRECISION_Class2_Std,RECALL_Mean,RECALL_Std,RECALL_Class1_Mean,RECALL_Class1_Std,RECALL_Class2_Mean,RECALL_Class2_Std,IOU_Mean,IOU_Std,IOU_Class1_Mean,IOU_Class1_Std,IOU_Class2_Mean,IOU_Class2_Std,SPECIFICITY_Mean,SPECIFICITY_Std,SPECIFICITY_Class1_Mean,SPECIFICITY_Class1_Std,SPECIFICITY_Class2_Mean,SPECIFICITY_Class2_Std,HD95_Mean,HD95_Std,HD95_Class1_Mean,HD95_Class1_Std,HD95_Class2_Mean,HD95_Class2_Std,LESION_SENSITIVITY_Mean,LESION_SENSITIVITY_Std,LESION_PRECISION_Mean,LESION_PRECISION_Std,LESION_F1_Mean,LESION_F1_Std,LESION_N_GT_LESIONS_Total,LESION_N_PRED_LESIONS_Total,LESION_TP_LESIONS_Total,LESION_FN_LESIONS_Total,LESION_FP_LESIONS_Total
|
| 2 |
+
1,unet,4,0.88837848442577,0.004488176438408171,0.9296308495604303,0.003051861083997252,0.8471261192911104,0.006988603634009174,0.9119885083644996,0.003899542620426997,0.937810327296536,0.004534371323946414,0.8861666894324636,0.004785959852918547,0.8689361210113892,0.004862525795740292,0.9221807114485115,0.002258280011483868,0.8156915305742668,0.008049307835491817,0.8026053795053394,0.006985123311053306,0.86883963293893,0.005257697310767231,0.7363711260717487,0.01045759631012397,0.9992450258555589,3.3829763475009536e-05,0.9992060262932462,5.3781685628696937e-05,0.9992840254178714,1.4410851399106235e-05,2.789638104058208,0.4953467782180657,1.0,0.0,4.579276208116416,0.9906935564361317,0.8161851575612329,0.010604103780010409,0.7479317737104494,0.01772268907341698,0.7741743428489778,0.01393389833458131,1100,1228,904,196,321
|
| 3 |
+
2,attnunet,4,0.8687324016526609,0.019964152607882032,0.9104890513851166,0.024899999222747722,0.826975751920205,0.015579775036519973,0.9044874168421302,0.015051711319037623,0.9203443411150141,0.02293220806285698,0.8886304925692461,0.009863625617703263,0.8409532379708647,0.02245478860306288,0.9019219497921485,0.026912289770452267,0.779984526149581,0.01966711140506486,0.7719119185728298,0.031123754171912342,0.83718265247343,0.040705456147877725,0.7066411846722295,0.022442058844092443,0.9991734550617465,0.000138385919780625,0.9989795287466985,0.000273307129456502,0.9993673813767945,5.564253319795054e-05,3.5482549554417186,0.7190357923335028,1.2282992876459566,0.3954259655345788,5.868210623237481,1.1565233310098124,0.8010968535966527,0.0030172783603819066,0.7248788539521523,0.05246431861595214,0.7510667501711839,0.0314226022122775,1100,1247,889,211,342
|
| 4 |
+
3,dlv3unet,4,0.8384415363144644,0.0022083747230999306,0.9005661992435416,0.0020867923289419102,0.7763168733853871,0.003073255677872925,0.8465323369072222,0.0071934443649559685,0.8997698242284116,0.002619190013214462,0.7932948495860329,0.01339398897949029,0.8336053369013228,0.0063294943822809775,0.9018365317555639,0.00304923631516394,0.7653741420470819,0.01188742959950416,0.7284198893837046,0.003170869073367662,0.8198187029362412,0.0034773428576675017,0.6370210758311682,0.003999380279290716,0.9986654847223488,7.953573517565171e-05,0.9987641261362938,3.990839508953154e-05,0.9985668433084038,0.00013274507182798207,2.8563464981341697,0.27781447224790423,1.0,0.0,4.7126929962683395,0.5556289444958085,0.720205635740483,0.024526594995177977,0.7255602516176558,0.02756091610984561,0.7136946491830041,0.022446218734487867,1100,1081,794,306,320
|
| 5 |
+
4,transunet,4,0.8784895989143288,0.007649748417636039,0.9246872887842248,0.004597522753464204,0.8322919090444327,0.010816310171427137,0.9177625999227637,0.00664998101397249,0.9320059959760637,0.011631626135529186,0.9035192038694635,0.003183241810633453,0.8473103690612981,0.006693789315880423,0.9184641125298365,0.004922383784531681,0.7761566255927599,0.015082099936625985,0.7872979059128824,0.011704790584520298,0.8603159951545595,0.007862485177144832,0.7142798166712054,0.01573238533985691,0.9992896422936255,8.191687124816227e-05,0.9991215386213639,0.00017857472387549319,0.9994577459658872,1.5563988190790074e-05,3.4645906109504088,0.964414340383405,1.0,0.0,5.929181221900818,1.9288286807668098,0.784994327470758,0.01749453278425242,0.7531778580152171,0.03475336972283068,0.7581287639607504,0.013636999062794438,1100,1208,878,222,307
|
models/for_WMH_Vent/folds_results_zscore2_all/variant_comparison_training.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Variant,Variant_Name,N_Folds,Best_Epoch_Mean,Best_Epoch_Std,Best_Epoch_Min,Best_Epoch_Max,Composite_Score_Mean,Composite_Score_Std,Best_Epoch_Val_Loss_Mean,Best_Epoch_Val_Loss_Std,Best_Epoch_Dice_Mean_Mean,Best_Epoch_Dice_Mean_Std,Best_Epoch_Dice_Ventricles_Mean,Best_Epoch_Dice_Ventricles_Std,Best_Epoch_Dice_Abnormal_WMH_Mean,Best_Epoch_Dice_Abnormal_WMH_Std
|
| 2 |
+
1,unet,4,42.75,4.815340071064556,36,49,0.8291107294531901,0.0159444009351335,0.284610815346241,0.02462779461924618,0.9092179358350473,0.00821557699733339,0.9332827928529066,0.005276587173391204,0.7959782885532373,0.02911192217235299
|
| 3 |
+
2,attnunet,4,37.5,2.8722813232690143,35,42,0.8172910527878241,0.03272955110785234,0.3047698736190796,0.05080431749397413,0.9017066905139327,0.018107910266314746,0.9212355924688412,0.013870856883885596,0.7856622706817994,0.044953429555710196
|
| 4 |
+
3,dlv3unet,4,34.75,6.7592529172978875,28,42,0.7876600843884642,0.018658861388516038,0.35384179651737213,0.029172388598382663,0.8850676805485197,0.009980085098274465,0.9044914996438328,0.0049556335292024285,0.7528280235784197,0.03190873102426834
|
| 5 |
+
4,transunet,4,37.5,7.22841614740048,28,48,0.8194273513199967,0.028025314220223092,0.2948276922106743,0.0365483168698032,0.904108618590283,0.01421469994866525,0.9316095523383316,0.0036677177073460732,0.7823787580659411,0.0446502980212195
|
models/for_WMH_Vent/model_training_scripts/attn_unet_model.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###################### Libraries ######################
|
| 2 |
+
# Deep Learning
|
| 3 |
+
import keras
|
| 4 |
+
from keras.models import Model
|
| 5 |
+
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_attention_unet_3class(input_shape=(256, 256, 1), num_classes=3):
|
| 9 |
+
"""Enhanced Attention U-Net architecture with dropout"""
|
| 10 |
+
|
| 11 |
+
def attention_block(F_g, F_l, F_int):
|
| 12 |
+
"""Attention gate implementation"""
|
| 13 |
+
W_g = Conv2D(F_int, 1, padding='same')(F_g)
|
| 14 |
+
W_x = Conv2D(F_int, 1, padding='same')(F_l)
|
| 15 |
+
psi = keras.layers.Add()([W_g, W_x])
|
| 16 |
+
psi = keras.layers.Activation('relu')(psi)
|
| 17 |
+
psi = Conv2D(1, 1, padding='same')(psi)
|
| 18 |
+
psi = keras.layers.Activation('sigmoid')(psi)
|
| 19 |
+
return keras.layers.Multiply()([F_l, psi])
|
| 20 |
+
|
| 21 |
+
inputs = Input(input_shape)
|
| 22 |
+
|
| 23 |
+
# Encoder with dropout (matching your original dropout pattern)
|
| 24 |
+
c1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
|
| 25 |
+
c1 = Conv2D(64, 3, activation='relu', padding='same')(c1)
|
| 26 |
+
p1 = MaxPooling2D(2)(c1)
|
| 27 |
+
p1 = keras.layers.Dropout(0.1)(p1)
|
| 28 |
+
|
| 29 |
+
c2 = Conv2D(128, 3, activation='relu', padding='same')(p1)
|
| 30 |
+
c2 = Conv2D(128, 3, activation='relu', padding='same')(c2)
|
| 31 |
+
p2 = MaxPooling2D(2)(c2)
|
| 32 |
+
p2 = keras.layers.Dropout(0.1)(p2)
|
| 33 |
+
|
| 34 |
+
c3 = Conv2D(256, 3, activation='relu', padding='same')(p2)
|
| 35 |
+
c3 = Conv2D(256, 3, activation='relu', padding='same')(c3)
|
| 36 |
+
p3 = MaxPooling2D(2)(c3)
|
| 37 |
+
p3 = keras.layers.Dropout(0.2)(p3)
|
| 38 |
+
|
| 39 |
+
c4 = Conv2D(512, 3, activation='relu', padding='same')(p3)
|
| 40 |
+
c4 = Conv2D(512, 3, activation='relu', padding='same')(c4)
|
| 41 |
+
p4 = MaxPooling2D(2)(c4)
|
| 42 |
+
p4 = keras.layers.Dropout(0.2)(p4)
|
| 43 |
+
|
| 44 |
+
# Bridge
|
| 45 |
+
c5 = Conv2D(1024, 3, activation='relu', padding='same')(p4)
|
| 46 |
+
c5 = Conv2D(1024, 3, activation='relu', padding='same')(c5)
|
| 47 |
+
c5 = keras.layers.Dropout(0.3)(c5)
|
| 48 |
+
|
| 49 |
+
# Decoder with attention gates (using Conv2DTranspose - more standard)
|
| 50 |
+
u6 = Conv2DTranspose(512, 2, strides=2, padding='same')(c5)
|
| 51 |
+
att6 = attention_block(u6, c4, 256)
|
| 52 |
+
u6 = concatenate([u6, att6])
|
| 53 |
+
u6 = keras.layers.Dropout(0.2)(u6)
|
| 54 |
+
c6 = Conv2D(512, 3, activation='relu', padding='same')(u6)
|
| 55 |
+
c6 = Conv2D(512, 3, activation='relu', padding='same')(c6)
|
| 56 |
+
|
| 57 |
+
u7 = Conv2DTranspose(256, 2, strides=2, padding='same')(c6)
|
| 58 |
+
att7 = attention_block(u7, c3, 128)
|
| 59 |
+
u7 = concatenate([u7, att7])
|
| 60 |
+
u7 = keras.layers.Dropout(0.2)(u7)
|
| 61 |
+
c7 = Conv2D(256, 3, activation='relu', padding='same')(u7)
|
| 62 |
+
c7 = Conv2D(256, 3, activation='relu', padding='same')(c7)
|
| 63 |
+
|
| 64 |
+
u8 = Conv2DTranspose(128, 2, strides=2, padding='same')(c7)
|
| 65 |
+
att8 = attention_block(u8, c2, 64)
|
| 66 |
+
u8 = concatenate([u8, att8])
|
| 67 |
+
u8 = keras.layers.Dropout(0.1)(u8)
|
| 68 |
+
c8 = Conv2D(128, 3, activation='relu', padding='same')(u8)
|
| 69 |
+
c8 = Conv2D(128, 3, activation='relu', padding='same')(c8)
|
| 70 |
+
|
| 71 |
+
u9 = Conv2DTranspose(64, 2, strides=2, padding='same')(c8)
|
| 72 |
+
att9 = attention_block(u9, c1, 32)
|
| 73 |
+
u9 = concatenate([u9, att9])
|
| 74 |
+
u9 = keras.layers.Dropout(0.1)(u9)
|
| 75 |
+
c9 = Conv2D(64, 3, activation='relu', padding='same')(u9)
|
| 76 |
+
c9 = Conv2D(64, 3, activation='relu', padding='same')(c9)
|
| 77 |
+
|
| 78 |
+
# Output layer - preserving your original conditional logic
|
| 79 |
+
if num_classes == 1:
|
| 80 |
+
outputs = Conv2D(1, 1, activation='sigmoid')(c9)
|
| 81 |
+
else:
|
| 82 |
+
outputs = Conv2D(num_classes, 1, activation='softmax')(c9)
|
| 83 |
+
|
| 84 |
+
return Model(inputs, outputs)
|
| 85 |
+
|
models/for_WMH_Vent/model_training_scripts/base_runner_all.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Run scripts one after another
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
for fold in range(4):
|
| 10 |
+
|
| 11 |
+
# Skip folds:
|
| 12 |
+
# if fold in list(np.array([0])):
|
| 13 |
+
# continue
|
| 14 |
+
|
| 15 |
+
for variant in range(5):
|
| 16 |
+
|
| 17 |
+
# Skip variants:
|
| 18 |
+
if variant in list(np.array([0])):
|
| 19 |
+
continue
|
| 20 |
+
|
| 21 |
+
# subprocess.run([sys.executable, "p4_run_experiments_all.py", "--variant", str(variant), "--fold", str(fold), "--scenario", "standard_3class"])
|
| 22 |
+
|
| 23 |
+
|
models/for_WMH_Vent/model_training_scripts/dlv3_unet_model.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###################### Libraries ######################
|
| 2 |
+
# Deep Learning
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import keras
|
| 5 |
+
from keras.models import Model, load_model
|
| 6 |
+
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
|
| 7 |
+
from keras import backend as K
|
| 8 |
+
from tensorflow.keras import layers, optimizers, callbacks
|
| 9 |
+
from keras.utils import to_categorical
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def build_deeplabv3_unet_3class(input_shape=(256, 256, 1), num_classes=3):
|
| 13 |
+
"""
|
| 14 |
+
Standard DeepLabV3+ implementation with ResNet-50 backbone
|
| 15 |
+
Following the original paper: "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def conv_block(x, filters, kernel_size=3, strides=1, dilation_rate=1, use_bias=False, name=None):
|
| 19 |
+
"""Standard convolution block with BN and ReLU"""
|
| 20 |
+
x = layers.Conv2D(filters, kernel_size, strides=strides, padding='same',
|
| 21 |
+
dilation_rate=dilation_rate, use_bias=use_bias, name=name)(x)
|
| 22 |
+
# x = layers.BatchNormalization()(x)
|
| 23 |
+
x = layers.Activation('relu')(x)
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
def bottleneck_residual_block(x, filters, strides=1, dilation_rate=1, projection_shortcut=False, name_prefix=""):
|
| 27 |
+
"""ResNet-50 bottleneck block with optional atrous convolution"""
|
| 28 |
+
shortcut = x
|
| 29 |
+
|
| 30 |
+
# Projection shortcut if needed
|
| 31 |
+
if projection_shortcut:
|
| 32 |
+
shortcut = layers.Conv2D(filters * 4, 1, strides=strides, use_bias=False,
|
| 33 |
+
name=f"{name_prefix}_0_conv")(shortcut)
|
| 34 |
+
# shortcut = layers.BatchNormalization(name=f"{name_prefix}_0_bn")(shortcut)
|
| 35 |
+
|
| 36 |
+
# Bottleneck layers
|
| 37 |
+
x = layers.Conv2D(filters, 1, use_bias=False, name=f"{name_prefix}_1_conv")(x)
|
| 38 |
+
# x = layers.BatchNormalization(name=f"{name_prefix}_1_bn")(x)
|
| 39 |
+
x = layers.Activation('relu')(x)
|
| 40 |
+
|
| 41 |
+
x = layers.Conv2D(filters, 3, strides=strides, padding='same',
|
| 42 |
+
dilation_rate=dilation_rate, use_bias=False, name=f"{name_prefix}_2_conv")(x)
|
| 43 |
+
# x = layers.BatchNormalization(name=f"{name_prefix}_2_bn")(x)
|
| 44 |
+
x = layers.Activation('relu')(x)
|
| 45 |
+
|
| 46 |
+
x = layers.Conv2D(filters * 4, 1, use_bias=False, name=f"{name_prefix}_3_conv")(x)
|
| 47 |
+
# x = layers.BatchNormalization(name=f"{name_prefix}_3_bn")(x)
|
| 48 |
+
|
| 49 |
+
x = layers.Add()([shortcut, x])
|
| 50 |
+
x = layers.Activation('relu')(x)
|
| 51 |
+
|
| 52 |
+
return x
|
| 53 |
+
|
| 54 |
+
def aspp_block(x, filters=256):
|
| 55 |
+
"""Atrous Spatial Pyramid Pooling with proper implementation"""
|
| 56 |
+
|
| 57 |
+
# ASPP branches
|
| 58 |
+
# 1x1 convolution
|
| 59 |
+
b1 = layers.Conv2D(filters, 1, use_bias=False, name='aspp_1x1')(x)
|
| 60 |
+
# b1 = layers.BatchNormalization(name='aspp_1x1_bn')(b1)
|
| 61 |
+
b1 = layers.Activation('relu')(b1)
|
| 62 |
+
|
| 63 |
+
# 3x3 convolution with rate = 6
|
| 64 |
+
b2 = layers.Conv2D(filters, 3, padding='same', dilation_rate=6, use_bias=False, name='aspp_3x3_6')(x)
|
| 65 |
+
# b2 = layers.BatchNormalization(name='aspp_3x3_6_bn')(b2)
|
| 66 |
+
b2 = layers.Activation('relu')(b2)
|
| 67 |
+
|
| 68 |
+
# 3x3 convolution with rate = 12
|
| 69 |
+
b3 = layers.Conv2D(filters, 3, padding='same', dilation_rate=12, use_bias=False, name='aspp_3x3_12')(x)
|
| 70 |
+
# b3 = layers.BatchNormalization(name='aspp_3x3_12_bn')(b3)
|
| 71 |
+
b3 = layers.Activation('relu')(b3)
|
| 72 |
+
|
| 73 |
+
# 3x3 convolution with rate = 18
|
| 74 |
+
b4 = layers.Conv2D(filters, 3, padding='same', dilation_rate=18, use_bias=False, name='aspp_3x3_18')(x)
|
| 75 |
+
# b4 = layers.BatchNormalization(name='aspp_3x3_18_bn')(b4)
|
| 76 |
+
b4 = layers.Activation('relu')(b4)
|
| 77 |
+
|
| 78 |
+
# Image-level features (Global Average Pooling) - Simplified approach
|
| 79 |
+
# Get input spatial dimensions
|
| 80 |
+
input_shape = tf.shape(x)
|
| 81 |
+
h, w = input_shape[1], input_shape[2]
|
| 82 |
+
|
| 83 |
+
b5 = layers.GlobalAveragePooling2D(name='aspp_gap')(x)
|
| 84 |
+
b5 = layers.Reshape((1, 1, -1))(b5)
|
| 85 |
+
b5 = layers.Conv2D(filters, 1, use_bias=False, name='aspp_gap_conv')(b5)
|
| 86 |
+
# b5 = layers.BatchNormalization(name='aspp_gap_bn')(b5)
|
| 87 |
+
b5 = layers.Activation('relu')(b5)
|
| 88 |
+
|
| 89 |
+
# Use a resize function that handles KerasTensors properly
|
| 90 |
+
def resize_to_input_shape(args):
|
| 91 |
+
features, spatial_shape = args
|
| 92 |
+
return tf.image.resize(features, spatial_shape, method='bilinear')
|
| 93 |
+
|
| 94 |
+
b5 = layers.Lambda(resize_to_input_shape, name='aspp_gap_resize')([b5, [h, w]])
|
| 95 |
+
|
| 96 |
+
# Concatenate all branches
|
| 97 |
+
concat_features = layers.Concatenate(name='aspp_concat')([b1, b2, b3, b4, b5])
|
| 98 |
+
|
| 99 |
+
# Final 1x1 convolution
|
| 100 |
+
output = layers.Conv2D(filters, 1, use_bias=False, name='aspp_final_conv')(concat_features)
|
| 101 |
+
# output = layers.BatchNormalization(name='aspp_final_bn')(output)
|
| 102 |
+
output = layers.Activation('relu')(output)
|
| 103 |
+
output = layers.Dropout(0.1, name='aspp_dropout')(output)
|
| 104 |
+
|
| 105 |
+
return output
|
| 106 |
+
|
| 107 |
+
# Input layer
|
| 108 |
+
inputs = layers.Input(input_shape, name='input')
|
| 109 |
+
|
| 110 |
+
# ==================== ENCODER (ResNet-50 Backbone) ====================
|
| 111 |
+
|
| 112 |
+
# Initial convolution
|
| 113 |
+
x = layers.Conv2D(64, 7, strides=2, padding='same', use_bias=False, name='conv1')(inputs)
|
| 114 |
+
# x = layers.BatchNormalization(name='conv1_bn')(x)
|
| 115 |
+
x = layers.Activation('relu')(x)
|
| 116 |
+
x = layers.MaxPooling2D(3, strides=2, padding='same', name='pool1')(x)
|
| 117 |
+
|
| 118 |
+
# Stage 1 (conv2_x) - Low-level features for decoder
|
| 119 |
+
x = bottleneck_residual_block(x, 64, strides=1, projection_shortcut=True, name_prefix='conv2_block1')
|
| 120 |
+
x = bottleneck_residual_block(x, 64, name_prefix='conv2_block2')
|
| 121 |
+
low_level_features = bottleneck_residual_block(x, 64, name_prefix='conv2_block3')
|
| 122 |
+
|
| 123 |
+
# Stage 2 (conv3_x)
|
| 124 |
+
x = bottleneck_residual_block(low_level_features, 128, strides=2, projection_shortcut=True, name_prefix='conv3_block1')
|
| 125 |
+
x = bottleneck_residual_block(x, 128, name_prefix='conv3_block2')
|
| 126 |
+
x = bottleneck_residual_block(x, 128, name_prefix='conv3_block3')
|
| 127 |
+
x = bottleneck_residual_block(x, 128, name_prefix='conv3_block4')
|
| 128 |
+
|
| 129 |
+
# Stage 3 (conv4_x) - With atrous convolution
|
| 130 |
+
x = bottleneck_residual_block(x, 256, strides=1, dilation_rate=2, projection_shortcut=True, name_prefix='conv4_block1')
|
| 131 |
+
x = bottleneck_residual_block(x, 256, dilation_rate=2, name_prefix='conv4_block2')
|
| 132 |
+
x = bottleneck_residual_block(x, 256, dilation_rate=2, name_prefix='conv4_block3')
|
| 133 |
+
x = bottleneck_residual_block(x, 256, dilation_rate=2, name_prefix='conv4_block4')
|
| 134 |
+
x = bottleneck_residual_block(x, 256, dilation_rate=2, name_prefix='conv4_block5')
|
| 135 |
+
x = bottleneck_residual_block(x, 256, dilation_rate=2, name_prefix='conv4_block6')
|
| 136 |
+
|
| 137 |
+
# Stage 4 (conv5_x) - With higher atrous rate
|
| 138 |
+
x = bottleneck_residual_block(x, 512, strides=1, dilation_rate=4, projection_shortcut=True, name_prefix='conv5_block1')
|
| 139 |
+
x = bottleneck_residual_block(x, 512, dilation_rate=4, name_prefix='conv5_block2')
|
| 140 |
+
x = bottleneck_residual_block(x, 512, dilation_rate=4, name_prefix='conv5_block3')
|
| 141 |
+
|
| 142 |
+
# ==================== ASPP MODULE ====================
|
| 143 |
+
x = aspp_block(x, filters=256)
|
| 144 |
+
|
| 145 |
+
# ==================== DECODER ====================
|
| 146 |
+
|
| 147 |
+
# Use fixed upsampling - the spatial relationship should be predictable
|
| 148 |
+
# ASPP output is at 1/16 resolution, low_level_features at 1/4 resolution
|
| 149 |
+
# So we need 4x upsampling to match
|
| 150 |
+
x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear', name='decoder_upsample1')(x)
|
| 151 |
+
|
| 152 |
+
# Process low-level features
|
| 153 |
+
low_level_features = layers.Conv2D(48, 1, use_bias=False, name='decoder_low_level_conv')(low_level_features)
|
| 154 |
+
# low_level_features = layers.BatchNormalization(name='decoder_low_level_bn')(low_level_features)
|
| 155 |
+
low_level_features = layers.Activation('relu')(low_level_features)
|
| 156 |
+
|
| 157 |
+
# If there's still a size mismatch, crop or pad to match
|
| 158 |
+
def match_spatial_dims(tensors):
|
| 159 |
+
high_level, low_level = tensors
|
| 160 |
+
# Get shapes
|
| 161 |
+
high_shape = tf.shape(high_level)
|
| 162 |
+
low_shape = tf.shape(low_level)
|
| 163 |
+
|
| 164 |
+
# Crop high_level to match low_level if it's larger
|
| 165 |
+
high_level_matched = high_level[:, :low_shape[1], :low_shape[2], :]
|
| 166 |
+
return high_level_matched, low_level
|
| 167 |
+
|
| 168 |
+
x_matched, low_level_matched = layers.Lambda(match_spatial_dims, name='match_dims')([x, low_level_features])
|
| 169 |
+
|
| 170 |
+
# Concatenate high-level and low-level features
|
| 171 |
+
x = layers.Concatenate(name='decoder_concat')([x_matched, low_level_matched])
|
| 172 |
+
|
| 173 |
+
# Refine features
|
| 174 |
+
x = layers.Conv2D(256, 3, padding='same', use_bias=False, name='decoder_conv1')(x)
|
| 175 |
+
# x = layers.BatchNormalization(name='decoder_conv1_bn')(x)
|
| 176 |
+
x = layers.Activation('relu')(x)
|
| 177 |
+
x = layers.Dropout(0.1, name='decoder_dropout1')(x) # Light regularization
|
| 178 |
+
|
| 179 |
+
x = layers.Conv2D(256, 3, padding='same', use_bias=False, name='decoder_conv2')(x)
|
| 180 |
+
# x = layers.BatchNormalization(name='decoder_conv2_bn')(x)
|
| 181 |
+
x = layers.Activation('relu')(x)
|
| 182 |
+
x = layers.Dropout(0.1, name='decoder_dropout2')(x)
|
| 183 |
+
|
| 184 |
+
# Final upsampling to original resolution (4x upsampling)
|
| 185 |
+
x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear', name='decoder_upsample2')(x)
|
| 186 |
+
|
| 187 |
+
# ==================== OUTPUT ====================
|
| 188 |
+
|
| 189 |
+
# Output layer - preserving your original conditional logic
|
| 190 |
+
if num_classes == 1:
|
| 191 |
+
outputs = layers.Conv2D(1, 1, activation='sigmoid', name='output')(x)
|
| 192 |
+
else:
|
| 193 |
+
outputs = layers.Conv2D(num_classes, 1, activation='softmax', name='output')(x)
|
| 194 |
+
|
| 195 |
+
# Create model
|
| 196 |
+
model = keras.Model(inputs, outputs, name='DeepLabV3Plus_ResNet50')
|
| 197 |
+
|
| 198 |
+
return model
|
models/for_WMH_Vent/model_training_scripts/dlv3_unet_model_GN.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###################### Libraries ######################
|
| 2 |
+
# Deep Learning
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import keras
|
| 5 |
+
from keras.models import Model, load_model
|
| 6 |
+
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
|
| 7 |
+
from keras import backend as K
|
| 8 |
+
from tensorflow.keras import layers, optimizers, callbacks
|
| 9 |
+
from keras.utils import to_categorical
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def build_deeplabv3_unet_3class(input_shape=(256, 256, 1), num_classes=3):
|
| 13 |
+
"""
|
| 14 |
+
DeepLabV3+ with ResNet-50 backbone.
|
| 15 |
+
|
| 16 |
+
Key fix over the original:
|
| 17 |
+
- All BatchNormalization replaced with GroupNormalization (groups=8).
|
| 18 |
+
GroupNorm is batch-size independent, so inference statistics are
|
| 19 |
+
identical whether training=True or training=False — no more need to
|
| 20 |
+
force training=True at inference time.
|
| 21 |
+
|
| 22 |
+
Input: single-channel (grayscale) MRI images → (H, W, 1)
|
| 23 |
+
Output: per-pixel class probabilities → (H, W, num_classes)
|
| 24 |
+
or binary mask → (H, W, 1) when num_classes==1
|
| 25 |
+
|
| 26 |
+
Reference:
|
| 27 |
+
"Encoder-Decoder with Atrous Separable Convolution for
|
| 28 |
+
Semantic Image Segmentation", Chen et al. 2018.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
# ------------------------------------------------------------------
|
| 32 |
+
# Helper: GroupNorm drop-in for BatchNorm
|
| 33 |
+
# groups=8 works well for filter counts ≥ 32 that are multiples of 8.
|
| 34 |
+
# ------------------------------------------------------------------
|
| 35 |
+
def group_norm(name=None):
|
| 36 |
+
return layers.GroupNormalization(groups=4, name=name)
|
| 37 |
+
|
| 38 |
+
# ------------------------------------------------------------------
|
| 39 |
+
def conv_block(x, filters, kernel_size=3, strides=1,
|
| 40 |
+
dilation_rate=1, use_bias=False, name=None):
|
| 41 |
+
"""Standard convolution block with GroupNorm and ReLU."""
|
| 42 |
+
x = layers.Conv2D(
|
| 43 |
+
filters, kernel_size, strides=strides, padding='same',
|
| 44 |
+
dilation_rate=dilation_rate, use_bias=use_bias, name=name
|
| 45 |
+
)(x)
|
| 46 |
+
x = group_norm()(x)
|
| 47 |
+
x = layers.Activation('relu')(x)
|
| 48 |
+
return x
|
| 49 |
+
|
| 50 |
+
# ------------------------------------------------------------------
|
| 51 |
+
def bottleneck_residual_block(x, filters, strides=1, dilation_rate=1,
|
| 52 |
+
projection_shortcut=False, name_prefix=""):
|
| 53 |
+
"""ResNet-50 bottleneck block with optional atrous convolution."""
|
| 54 |
+
shortcut = x
|
| 55 |
+
|
| 56 |
+
# Projection shortcut if dimensions change
|
| 57 |
+
if projection_shortcut:
|
| 58 |
+
shortcut = layers.Conv2D(
|
| 59 |
+
filters * 4, 1, strides=strides, use_bias=False,
|
| 60 |
+
name=f"{name_prefix}_0_conv"
|
| 61 |
+
)(shortcut)
|
| 62 |
+
shortcut = group_norm(name=f"{name_prefix}_0_gn")(shortcut)
|
| 63 |
+
|
| 64 |
+
# 1×1 → 3×3 (possibly atrous) → 1×1 bottleneck
|
| 65 |
+
x = layers.Conv2D(filters, 1, use_bias=False,
|
| 66 |
+
name=f"{name_prefix}_1_conv")(x)
|
| 67 |
+
x = group_norm(name=f"{name_prefix}_1_gn")(x)
|
| 68 |
+
x = layers.Activation('relu')(x)
|
| 69 |
+
|
| 70 |
+
x = layers.Conv2D(
|
| 71 |
+
filters, 3, strides=strides, padding='same',
|
| 72 |
+
dilation_rate=dilation_rate, use_bias=False,
|
| 73 |
+
name=f"{name_prefix}_2_conv"
|
| 74 |
+
)(x)
|
| 75 |
+
x = group_norm(name=f"{name_prefix}_2_gn")(x)
|
| 76 |
+
x = layers.Activation('relu')(x)
|
| 77 |
+
|
| 78 |
+
x = layers.Conv2D(filters * 4, 1, use_bias=False,
|
| 79 |
+
name=f"{name_prefix}_3_conv")(x)
|
| 80 |
+
x = group_norm(name=f"{name_prefix}_3_gn")(x)
|
| 81 |
+
|
| 82 |
+
x = layers.Add()([shortcut, x])
|
| 83 |
+
x = layers.Activation('relu')(x)
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
# ------------------------------------------------------------------
|
| 87 |
+
def aspp_block(x, filters=256):
|
| 88 |
+
"""Atrous Spatial Pyramid Pooling."""
|
| 89 |
+
|
| 90 |
+
# Branch 1 — 1×1 conv
|
| 91 |
+
b1 = layers.Conv2D(filters, 1, use_bias=False, name='aspp_1x1')(x)
|
| 92 |
+
b1 = group_norm(name='aspp_1x1_gn')(b1)
|
| 93 |
+
b1 = layers.Activation('relu')(b1)
|
| 94 |
+
|
| 95 |
+
# Branch 2 — 3×3, rate=6
|
| 96 |
+
b2 = layers.Conv2D(filters, 3, padding='same', dilation_rate=6,
|
| 97 |
+
use_bias=False, name='aspp_3x3_6')(x)
|
| 98 |
+
b2 = group_norm(name='aspp_3x3_6_gn')(b2)
|
| 99 |
+
b2 = layers.Activation('relu')(b2)
|
| 100 |
+
|
| 101 |
+
# Branch 3 — 3×3, rate=12
|
| 102 |
+
b3 = layers.Conv2D(filters, 3, padding='same', dilation_rate=12,
|
| 103 |
+
use_bias=False, name='aspp_3x3_12')(x)
|
| 104 |
+
b3 = group_norm(name='aspp_3x3_12_gn')(b3)
|
| 105 |
+
b3 = layers.Activation('relu')(b3)
|
| 106 |
+
|
| 107 |
+
# Branch 4 — 3×3, rate=18
|
| 108 |
+
b4 = layers.Conv2D(filters, 3, padding='same', dilation_rate=18,
|
| 109 |
+
use_bias=False, name='aspp_3x3_18')(x)
|
| 110 |
+
b4 = group_norm(name='aspp_3x3_18_gn')(b4)
|
| 111 |
+
b4 = layers.Activation('relu')(b4)
|
| 112 |
+
|
| 113 |
+
# Branch 5 — image-level global context via GAP + resize
|
| 114 |
+
input_shape_dyn = tf.shape(x)
|
| 115 |
+
h, w = input_shape_dyn[1], input_shape_dyn[2]
|
| 116 |
+
|
| 117 |
+
b5 = layers.GlobalAveragePooling2D(name='aspp_gap')(x)
|
| 118 |
+
b5 = layers.Reshape((1, 1, -1))(b5)
|
| 119 |
+
b5 = layers.Conv2D(filters, 1, use_bias=False,
|
| 120 |
+
name='aspp_gap_conv')(b5)
|
| 121 |
+
b5 = group_norm(name='aspp_gap_gn')(b5)
|
| 122 |
+
b5 = layers.Activation('relu')(b5)
|
| 123 |
+
b5 = layers.Lambda(
|
| 124 |
+
lambda args: tf.image.resize(args[0], args[1], method='bilinear'),
|
| 125 |
+
name='aspp_gap_resize'
|
| 126 |
+
)([b5, [h, w]])
|
| 127 |
+
|
| 128 |
+
# Fuse all branches
|
| 129 |
+
concat = layers.Concatenate(name='aspp_concat')([b1, b2, b3, b4, b5])
|
| 130 |
+
out = layers.Conv2D(filters, 1, use_bias=False,
|
| 131 |
+
name='aspp_final_conv')(concat)
|
| 132 |
+
out = group_norm(name='aspp_final_gn')(out)
|
| 133 |
+
out = layers.Activation('relu')(out)
|
| 134 |
+
out = layers.Dropout(0.1, name='aspp_dropout')(out)
|
| 135 |
+
return out
|
| 136 |
+
|
| 137 |
+
# ==================================================================
|
| 138 |
+
# INPUT — grayscale, single channel
|
| 139 |
+
# ==================================================================
|
| 140 |
+
inputs = layers.Input(input_shape, name='input') # (H, W, 1)
|
| 141 |
+
|
| 142 |
+
# ==================================================================
|
| 143 |
+
# ENCODER — ResNet-50 backbone
|
| 144 |
+
# ==================================================================
|
| 145 |
+
|
| 146 |
+
# Stem
|
| 147 |
+
x = layers.Conv2D(64, 7, strides=2, padding='same',
|
| 148 |
+
use_bias=False, name='conv1')(inputs)
|
| 149 |
+
x = group_norm(name='conv1_gn')(x)
|
| 150 |
+
x = layers.Activation('relu')(x)
|
| 151 |
+
x = layers.MaxPooling2D(3, strides=2, padding='same', name='pool1')(x)
|
| 152 |
+
|
| 153 |
+
# Stage 1 — conv2_x (output stride 4 → low-level features for decoder)
|
| 154 |
+
x = bottleneck_residual_block(x, 64, strides=1,
|
| 155 |
+
projection_shortcut=True,
|
| 156 |
+
name_prefix='conv2_block1')
|
| 157 |
+
x = bottleneck_residual_block(x, 64, name_prefix='conv2_block2')
|
| 158 |
+
low_level_features = bottleneck_residual_block(x, 64,
|
| 159 |
+
name_prefix='conv2_block3')
|
| 160 |
+
|
| 161 |
+
# Stage 2 — conv3_x (output stride 8)
|
| 162 |
+
x = bottleneck_residual_block(low_level_features, 128, strides=2,
|
| 163 |
+
projection_shortcut=True,
|
| 164 |
+
name_prefix='conv3_block1')
|
| 165 |
+
x = bottleneck_residual_block(x, 128, name_prefix='conv3_block2')
|
| 166 |
+
x = bottleneck_residual_block(x, 128, name_prefix='conv3_block3')
|
| 167 |
+
x = bottleneck_residual_block(x, 128, name_prefix='conv3_block4')
|
| 168 |
+
|
| 169 |
+
# Stage 3 — conv4_x (atrous rate=2, keeps stride at 8)
|
| 170 |
+
x = bottleneck_residual_block(x, 256, strides=1, dilation_rate=2,
|
| 171 |
+
projection_shortcut=True,
|
| 172 |
+
name_prefix='conv4_block1')
|
| 173 |
+
for i in range(2, 7):
|
| 174 |
+
x = bottleneck_residual_block(x, 256, dilation_rate=2,
|
| 175 |
+
name_prefix=f'conv4_block{i}')
|
| 176 |
+
|
| 177 |
+
# Stage 4 — conv5_x (atrous rate=4, keeps stride at 8)
|
| 178 |
+
x = bottleneck_residual_block(x, 512, strides=1, dilation_rate=4,
|
| 179 |
+
projection_shortcut=True,
|
| 180 |
+
name_prefix='conv5_block1')
|
| 181 |
+
x = bottleneck_residual_block(x, 512, dilation_rate=4,
|
| 182 |
+
name_prefix='conv5_block2')
|
| 183 |
+
x = bottleneck_residual_block(x, 512, dilation_rate=4,
|
| 184 |
+
name_prefix='conv5_block3')
|
| 185 |
+
|
| 186 |
+
# ==================================================================
|
| 187 |
+
# ASPP MODULE
|
| 188 |
+
# ==================================================================
|
| 189 |
+
x = aspp_block(x, filters=256)
|
| 190 |
+
|
| 191 |
+
# ==================================================================
|
| 192 |
+
# DECODER
|
| 193 |
+
# ==================================================================
|
| 194 |
+
|
| 195 |
+
# 4× upsample to reach low-level feature resolution (output stride 4)
|
| 196 |
+
x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear',
|
| 197 |
+
name='decoder_upsample1')(x)
|
| 198 |
+
|
| 199 |
+
# Reduce low-level feature channels to 48 (as in the original paper)
|
| 200 |
+
low_level_features = layers.Conv2D(
|
| 201 |
+
48, 1, use_bias=False, name='decoder_low_level_conv'
|
| 202 |
+
)(low_level_features)
|
| 203 |
+
low_level_features = group_norm(name='decoder_low_level_gn')(low_level_features)
|
| 204 |
+
low_level_features = layers.Activation('relu')(low_level_features)
|
| 205 |
+
|
| 206 |
+
# Align spatial dims in case of any off-by-one from pooling
|
| 207 |
+
def match_spatial_dims(tensors):
|
| 208 |
+
high_level, low_level = tensors
|
| 209 |
+
low_shape = tf.shape(low_level)
|
| 210 |
+
return high_level[:, :low_shape[1], :low_shape[2], :], low_level
|
| 211 |
+
|
| 212 |
+
x_matched, low_matched = layers.Lambda(
|
| 213 |
+
match_spatial_dims, name='match_dims'
|
| 214 |
+
)([x, low_level_features])
|
| 215 |
+
|
| 216 |
+
# Fuse high-level and low-level features
|
| 217 |
+
x = layers.Concatenate(name='decoder_concat')([x_matched, low_matched])
|
| 218 |
+
|
| 219 |
+
x = layers.Conv2D(256, 3, padding='same', use_bias=False,
|
| 220 |
+
name='decoder_conv1')(x)
|
| 221 |
+
x = group_norm(name='decoder_conv1_gn')(x)
|
| 222 |
+
x = layers.Activation('relu')(x)
|
| 223 |
+
x = layers.Dropout(0.1, name='decoder_dropout1')(x)
|
| 224 |
+
|
| 225 |
+
x = layers.Conv2D(256, 3, padding='same', use_bias=False,
|
| 226 |
+
name='decoder_conv2')(x)
|
| 227 |
+
x = group_norm(name='decoder_conv2_gn')(x)
|
| 228 |
+
x = layers.Activation('relu')(x)
|
| 229 |
+
x = layers.Dropout(0.1, name='decoder_dropout2')(x)
|
| 230 |
+
|
| 231 |
+
# Final 4× upsample back to original resolution
|
| 232 |
+
x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear',
|
| 233 |
+
name='decoder_upsample2')(x)
|
| 234 |
+
|
| 235 |
+
# ==================================================================
|
| 236 |
+
# OUTPUT
|
| 237 |
+
# ==================================================================
|
| 238 |
+
if num_classes == 1:
|
| 239 |
+
# Binary segmentation → sigmoid, single-channel mask
|
| 240 |
+
outputs = layers.Conv2D(1, 1, activation='sigmoid', name='output')(x)
|
| 241 |
+
else:
|
| 242 |
+
# Multi-class segmentation → softmax over num_classes channels
|
| 243 |
+
outputs = layers.Conv2D(num_classes, 1, activation='softmax',
|
| 244 |
+
name='output')(x)
|
| 245 |
+
|
| 246 |
+
model = keras.Model(inputs, outputs, name='DeepLabV3Plus_ResNet50_GN')
|
| 247 |
+
return model
|
models/for_WMH_Vent/model_training_scripts/p4_compute_class_weights.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 - Utility script to calculate inverse frequency weights for class balancing
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python p4_compute_class_weights.py --fold 0 --scenario 4class --preprocessing standard
|
| 6 |
+
|
| 7 |
+
Output:
|
| 8 |
+
Saves class weights to JSON file for reproducibility
|
| 9 |
+
Prints weights for use in training
|
| 10 |
+
|
| 11 |
+
Authors:
|
| 12 |
+
"Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
|
| 13 |
+
|
| 14 |
+
Developer:
|
| 15 |
+
"Mahdi Bashiri Bawil"
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import json
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
import argparse
|
| 23 |
+
|
| 24 |
+
# Import data loader
|
| 25 |
+
from p4_data_loader import DataConfig, P2DataLoader
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def compute_class_frequencies(dataset, num_classes, total_samples=None):
|
| 29 |
+
"""
|
| 30 |
+
Compute class frequencies from dataset
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
dataset: TensorFlow dataset yielding (paired_input, target_mask)
|
| 34 |
+
num_classes: Number of classes (3 or 4)
|
| 35 |
+
total_samples: Total number of samples (for progress bar)
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
class_pixel_counts: Array of pixel counts per class
|
| 39 |
+
total_pixels: Total number of pixels analyzed
|
| 40 |
+
"""
|
| 41 |
+
class_pixel_counts = np.zeros(num_classes, dtype=np.int64)
|
| 42 |
+
total_pixels = 0
|
| 43 |
+
|
| 44 |
+
print(f"Computing class frequencies for {num_classes}-class scenario...")
|
| 45 |
+
|
| 46 |
+
iterator = tqdm(dataset, total=total_samples, desc="Processing") if total_samples else dataset
|
| 47 |
+
|
| 48 |
+
for paired_input, target_mask, _, _ in iterator:
|
| 49 |
+
# target_mask shape: (batch_size, 256, 256)
|
| 50 |
+
masks = target_mask.numpy()
|
| 51 |
+
|
| 52 |
+
for mask in masks:
|
| 53 |
+
# Count pixels for each class
|
| 54 |
+
for class_id in range(num_classes):
|
| 55 |
+
class_pixel_counts[class_id] += np.sum(mask == class_id)
|
| 56 |
+
|
| 57 |
+
total_pixels += mask.size
|
| 58 |
+
|
| 59 |
+
return class_pixel_counts, total_pixels
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_inverse_frequency_weights(class_pixel_counts, num_classes):
|
| 63 |
+
"""
|
| 64 |
+
Compute inverse frequency weights with normalization
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
class_pixel_counts: Array of pixel counts per class
|
| 68 |
+
num_classes: Number of classes
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
class_weights: Normalized inverse frequency weights
|
| 72 |
+
class_frequencies: Class frequencies (for reference)
|
| 73 |
+
"""
|
| 74 |
+
total_pixels = np.sum(class_pixel_counts)
|
| 75 |
+
|
| 76 |
+
# Class frequencies
|
| 77 |
+
class_frequencies = class_pixel_counts / total_pixels
|
| 78 |
+
|
| 79 |
+
# Inverse frequency (with small epsilon to avoid division by zero)
|
| 80 |
+
epsilon = 1e-6
|
| 81 |
+
inverse_freq = 1.0 / (class_frequencies + epsilon)
|
| 82 |
+
|
| 83 |
+
# Normalize weights to sum = num_classes
|
| 84 |
+
# This keeps weights in a reasonable range while maintaining relative importance
|
| 85 |
+
class_weights = inverse_freq / np.sum(inverse_freq) * num_classes
|
| 86 |
+
|
| 87 |
+
return class_weights, class_frequencies
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def compute_and_save_class_weights(fold_id, class_scenario, preprocessing,
|
| 91 |
+
output_dir='class_weights'):
|
| 92 |
+
"""
|
| 93 |
+
Compute class weights for a specific fold and scenario
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
fold_id: Fold number (0-4)
|
| 97 |
+
class_scenario: '3class' or '4class'
|
| 98 |
+
preprocessing: 'standard' or 'zoomed'
|
| 99 |
+
output_dir: Directory to save weights
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Dictionary with weights and statistics
|
| 103 |
+
"""
|
| 104 |
+
print("\n" + "="*70)
|
| 105 |
+
print(f"COMPUTING CLASS WEIGHTS")
|
| 106 |
+
print("="*70)
|
| 107 |
+
print(f"Fold: {fold_id}")
|
| 108 |
+
print(f"Scenario: {class_scenario}")
|
| 109 |
+
print(f"Preprocessing: {preprocessing}")
|
| 110 |
+
print("="*70 + "\n")
|
| 111 |
+
|
| 112 |
+
# Initialize data loader
|
| 113 |
+
config = DataConfig()
|
| 114 |
+
data_loader = P2DataLoader(config)
|
| 115 |
+
|
| 116 |
+
# Determine number of classes
|
| 117 |
+
num_classes = 3 if class_scenario == '3class' else 4
|
| 118 |
+
|
| 119 |
+
# Load training dataset
|
| 120 |
+
print("Loading training dataset...")
|
| 121 |
+
train_dataset = data_loader.create_dataset_for_fold(
|
| 122 |
+
fold_id=fold_id,
|
| 123 |
+
split='train',
|
| 124 |
+
preprocessing=preprocessing,
|
| 125 |
+
class_scenario=class_scenario,
|
| 126 |
+
batch_size=8, # Larger batch for faster processing
|
| 127 |
+
shuffle=False # No need to shuffle for counting
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Get dataset size
|
| 131 |
+
train_size = sum(1 for _ in train_dataset)
|
| 132 |
+
print(f"Training samples: {train_size}")
|
| 133 |
+
|
| 134 |
+
# Recreate dataset after consuming
|
| 135 |
+
train_dataset = data_loader.create_dataset_for_fold(
|
| 136 |
+
fold_id=fold_id,
|
| 137 |
+
split='train',
|
| 138 |
+
preprocessing=preprocessing,
|
| 139 |
+
class_scenario=class_scenario,
|
| 140 |
+
batch_size=8,
|
| 141 |
+
shuffle=False
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Compute class frequencies
|
| 145 |
+
class_pixel_counts, total_pixels = compute_class_frequencies(
|
| 146 |
+
train_dataset, num_classes, train_size
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# Compute inverse frequency weights
|
| 150 |
+
class_weights, class_frequencies = compute_inverse_frequency_weights(
|
| 151 |
+
class_pixel_counts, num_classes
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Print results
|
| 155 |
+
print("\n" + "="*70)
|
| 156 |
+
print("RESULTS")
|
| 157 |
+
print("="*70)
|
| 158 |
+
|
| 159 |
+
class_names = {
|
| 160 |
+
3: ['Background', 'Ventricles', 'Abnormal WMH'],
|
| 161 |
+
4: ['Background', 'Ventricles', 'Normal WMH', 'Abnormal WMH']
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
print(f"\nTotal pixels analyzed: {total_pixels:,}")
|
| 165 |
+
print(f"\nClass Statistics:")
|
| 166 |
+
print("-" * 70)
|
| 167 |
+
|
| 168 |
+
for i in range(num_classes):
|
| 169 |
+
print(f"Class {i} ({class_names[num_classes][i]}):")
|
| 170 |
+
print(f" Pixel count: {class_pixel_counts[i]:,}")
|
| 171 |
+
print(f" Frequency: {class_frequencies[i]:.6f} ({class_frequencies[i]*100:.2f}%)")
|
| 172 |
+
print(f" Weight: {class_weights[i]:.4f}")
|
| 173 |
+
print()
|
| 174 |
+
|
| 175 |
+
# Save to JSON
|
| 176 |
+
output_path = Path(output_dir)
|
| 177 |
+
output_path.mkdir(exist_ok=True)
|
| 178 |
+
|
| 179 |
+
results = {
|
| 180 |
+
'fold_id': fold_id,
|
| 181 |
+
'class_scenario': class_scenario,
|
| 182 |
+
'preprocessing': preprocessing,
|
| 183 |
+
'num_classes': num_classes,
|
| 184 |
+
'total_pixels': int(total_pixels),
|
| 185 |
+
'class_pixel_counts': class_pixel_counts.tolist(),
|
| 186 |
+
'class_frequencies': class_frequencies.tolist(),
|
| 187 |
+
'class_weights': class_weights.tolist(),
|
| 188 |
+
'class_names': class_names[num_classes]
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
filename = f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
|
| 192 |
+
filepath = output_path / filename
|
| 193 |
+
|
| 194 |
+
with open(filepath, 'w') as f:
|
| 195 |
+
json.dump(results, f, indent=2)
|
| 196 |
+
|
| 197 |
+
print("="*70)
|
| 198 |
+
print(f"✅ Class weights saved to: {filepath}")
|
| 199 |
+
print("="*70)
|
| 200 |
+
|
| 201 |
+
# Print weights in format ready for code
|
| 202 |
+
print("\nFor use in training script:")
|
| 203 |
+
print("-" * 70)
|
| 204 |
+
print(f"class_weights = tf.constant({class_weights.tolist()}, dtype=tf.float32)")
|
| 205 |
+
print()
|
| 206 |
+
|
| 207 |
+
return results
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def compute_all_scenarios_for_fold(fold_id):
|
| 211 |
+
"""
|
| 212 |
+
Compute class weights for all 4 scenarios of a given fold
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
fold_id: Fold number (0-4)
|
| 216 |
+
"""
|
| 217 |
+
scenarios = [
|
| 218 |
+
{'preprocessing': 'standard', 'class_scenario': '3class'},
|
| 219 |
+
{'preprocessing': 'standard', 'class_scenario': '4class'},
|
| 220 |
+
{'preprocessing': 'zoomed', 'class_scenario': '3class'},
|
| 221 |
+
{'preprocessing': 'zoomed', 'class_scenario': '4class'},
|
| 222 |
+
]
|
| 223 |
+
|
| 224 |
+
all_results = {}
|
| 225 |
+
|
| 226 |
+
for scenario in scenarios:
|
| 227 |
+
results = compute_and_save_class_weights(
|
| 228 |
+
fold_id=fold_id,
|
| 229 |
+
class_scenario=scenario['class_scenario'],
|
| 230 |
+
preprocessing=scenario['preprocessing']
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
key = f"{scenario['preprocessing']}_{scenario['class_scenario']}"
|
| 234 |
+
all_results[key] = results
|
| 235 |
+
|
| 236 |
+
print("\n" + "="*70 + "\n")
|
| 237 |
+
|
| 238 |
+
return all_results
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def load_class_weights(fold_id, class_scenario, preprocessing, weights_dir='class_weights'):
|
| 242 |
+
"""
|
| 243 |
+
Load previously computed class weights
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
fold_id: Fold number (0-4)
|
| 247 |
+
class_scenario: '3class' or '4class'
|
| 248 |
+
preprocessing: 'standard' or 'zoomed'
|
| 249 |
+
weights_dir: Directory containing weights files
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
class_weights: NumPy array of weights
|
| 253 |
+
"""
|
| 254 |
+
weights_path = Path(weights_dir)
|
| 255 |
+
filename = f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
|
| 256 |
+
filepath = weights_path / filename
|
| 257 |
+
|
| 258 |
+
if not filepath.exists():
|
| 259 |
+
raise FileNotFoundError(
|
| 260 |
+
f"Class weights not found: {filepath}\n"
|
| 261 |
+
f"Run compute_and_save_class_weights() first."
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
with open(filepath, 'r') as f:
|
| 265 |
+
results = json.load(f)
|
| 266 |
+
|
| 267 |
+
class_weights = np.array(results['class_weights'], dtype=np.float32)
|
| 268 |
+
|
| 269 |
+
return class_weights
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def main():
|
| 273 |
+
"""Main entry point with argument parsing"""
|
| 274 |
+
parser = argparse.ArgumentParser(
|
| 275 |
+
description='Compute class weights from training data',
|
| 276 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 277 |
+
epilog="""
|
| 278 |
+
Examples:
|
| 279 |
+
# Single scenario
|
| 280 |
+
python p2_compute_class_weights.py --fold 0 --scenario 4class --preprocessing standard
|
| 281 |
+
|
| 282 |
+
# All scenarios for one fold
|
| 283 |
+
python p2_compute_class_weights.py --fold 0 --all
|
| 284 |
+
|
| 285 |
+
# All folds (for completeness)
|
| 286 |
+
python p2_compute_class_weights.py --all-folds
|
| 287 |
+
"""
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
parser.add_argument(
|
| 291 |
+
'--fold',
|
| 292 |
+
type=int,
|
| 293 |
+
choices=[0, 1, 2, 3, 4],
|
| 294 |
+
help='Fold number (0-4)'
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
parser.add_argument(
|
| 298 |
+
'--scenario',
|
| 299 |
+
type=str,
|
| 300 |
+
choices=['3class', '4class'],
|
| 301 |
+
help='Class scenario'
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
parser.add_argument(
|
| 305 |
+
'--preprocessing',
|
| 306 |
+
type=str,
|
| 307 |
+
choices=['standard', 'zoomed'],
|
| 308 |
+
help='Preprocessing type'
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
'--all',
|
| 313 |
+
action='store_true',
|
| 314 |
+
help='Compute for all scenarios of specified fold'
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
parser.add_argument(
|
| 318 |
+
'--all-folds',
|
| 319 |
+
action='store_true',
|
| 320 |
+
help='Compute for all scenarios of all folds'
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
args = parser.parse_args()
|
| 324 |
+
|
| 325 |
+
# Validate arguments
|
| 326 |
+
if args.all_folds:
|
| 327 |
+
# Compute for all folds
|
| 328 |
+
for fold_id in range(5):
|
| 329 |
+
print(f"\n{'='*70}")
|
| 330 |
+
print(f"PROCESSING FOLD {fold_id}")
|
| 331 |
+
print(f"{'='*70}\n")
|
| 332 |
+
compute_all_scenarios_for_fold(fold_id)
|
| 333 |
+
|
| 334 |
+
elif args.all:
|
| 335 |
+
# Compute all scenarios for one fold
|
| 336 |
+
if args.fold is None:
|
| 337 |
+
parser.error("--fold is required when using --all")
|
| 338 |
+
compute_all_scenarios_for_fold(args.fold)
|
| 339 |
+
|
| 340 |
+
else:
|
| 341 |
+
# Compute single scenario
|
| 342 |
+
if args.fold is None or args.scenario is None or args.preprocessing is None:
|
| 343 |
+
parser.error("--fold, --scenario, and --preprocessing are required")
|
| 344 |
+
|
| 345 |
+
compute_and_save_class_weights(
|
| 346 |
+
fold_id=args.fold,
|
| 347 |
+
class_scenario=args.scenario,
|
| 348 |
+
preprocessing=args.preprocessing
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
if __name__ == "__main__":
|
| 353 |
+
main()
|
models/for_WMH_Vent/model_training_scripts/p4_data_loader.py
ADDED
|
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 Article - Data Loading System
|
| 3 |
+
|
| 4 |
+
Complete implementation for brain segmentation experiments
|
| 5 |
+
|
| 6 |
+
WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
|
| 7 |
+
Three-class segmentation: Background vs Ventricles vs Abnormal WMH
|
| 8 |
+
Professional results saving and visualization for publication
|
| 9 |
+
|
| 10 |
+
This relates to our article:
|
| 11 |
+
"Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
|
| 12 |
+
A Large-Scale Multiple Sclerosis Study"
|
| 13 |
+
|
| 14 |
+
Features:
|
| 15 |
+
- Load FLAIR images and individual mask files from Cohort directory
|
| 16 |
+
- Support both Local_SAI (MS3SEG) and Public_MSSEG (MSSEG2016) datasets
|
| 17 |
+
- Handle standard and zoomed preprocessing variants
|
| 18 |
+
- Combine masks into 3-class or 4-class format
|
| 19 |
+
- Create paired inputs: [FLAIR | mask] concatenated (256x512)
|
| 20 |
+
- Patient-stratified K-fold cross-validation
|
| 21 |
+
- TensorFlow dataset creation with proper batching
|
| 22 |
+
|
| 23 |
+
Authors:
|
| 24 |
+
"Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
|
| 25 |
+
|
| 26 |
+
Developer:
|
| 27 |
+
"Mahdi Bashiri Bawil"
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import os
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Tuple, List, Dict, Optional
|
| 34 |
+
import json
|
| 35 |
+
from sklearn.model_selection import KFold
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
import cv2 as cv
|
| 38 |
+
|
| 39 |
+
# Deep Learning
|
| 40 |
+
import tensorflow as tf
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
###################### Configuration ######################
|
| 44 |
+
|
| 45 |
+
class DataConfig:
|
| 46 |
+
"""Data configuration for P4 experiments"""
|
| 47 |
+
|
| 48 |
+
def __init__(self):
|
| 49 |
+
# Base paths
|
| 50 |
+
self.cohort_dir = Path("/mnt/e/MBashiri/ours_articles/Paper#2/Data/Cohort") # CHANGE THIS to your actual path of Data Cohort
|
| 51 |
+
|
| 52 |
+
# Dataset configurations
|
| 53 |
+
self.datasets = {
|
| 54 |
+
'Local_SAI_updated': {
|
| 55 |
+
'base_path': self.cohort_dir / 'Local_SAI_updated',
|
| 56 |
+
'slice_range': (1, 20), # inclusive range 9,15
|
| 57 |
+
'patient_prefix_length': 6 # "101228"
|
| 58 |
+
},
|
| 59 |
+
'Public_MSSEG': {
|
| 60 |
+
'base_path': self.cohort_dir / 'Public_MSSEG',
|
| 61 |
+
'slice_range': (1, 50), # inclusive range 24,43
|
| 62 |
+
'patient_prefix_length': 6 # "c01p01"
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
# Preprocessing variants
|
| 67 |
+
self.preprocessing_types = ['standard', 'zoomed']
|
| 68 |
+
|
| 69 |
+
# Class scenarios
|
| 70 |
+
self.class_scenarios = {
|
| 71 |
+
'3class': {
|
| 72 |
+
'num_classes': 3,
|
| 73 |
+
'class_names': ['Background', 'Ventricles', 'Abnormal WMH'],
|
| 74 |
+
'description': 'Three-class: Background, Ventricles, Abnormal WMH',
|
| 75 |
+
'class_mapping': {
|
| 76 |
+
'background': 0,
|
| 77 |
+
'ventricles': 1,
|
| 78 |
+
'abnormal_wmh': 2,
|
| 79 |
+
}
|
| 80 |
+
},
|
| 81 |
+
'4class': {
|
| 82 |
+
'num_classes': 4,
|
| 83 |
+
'class_names': ['Background', 'Ventricles', 'Normal WMH', 'Abnormal WMH'],
|
| 84 |
+
'description': 'Four-class: Background, Ventricles, Normal WMH, Abnormal WMH',
|
| 85 |
+
'class_mapping': {
|
| 86 |
+
'background': 0,
|
| 87 |
+
'ventricles': 1,
|
| 88 |
+
'normal_wmh': 2,
|
| 89 |
+
'abnormal_wmh': 3
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# K-fold parameters
|
| 95 |
+
self.k_folds = 4
|
| 96 |
+
self.test_split = 0.2 # 20% for test set
|
| 97 |
+
self.random_state = 42
|
| 98 |
+
|
| 99 |
+
# Image parameters
|
| 100 |
+
self.target_size = (256, 256)
|
| 101 |
+
self.paired_width = 512 # FLAIR (256) + mask (256)
|
| 102 |
+
|
| 103 |
+
# Paths for splits
|
| 104 |
+
self.splits_dir = Path("data_splits")
|
| 105 |
+
self.splits_file = self.splits_dir / "concat_fold_assignments.json"
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
###################### Helper Functions ######################
|
| 109 |
+
|
| 110 |
+
def extract_patient_id(filename: str, prefix_length: int = 6) -> str:
|
| 111 |
+
"""
|
| 112 |
+
Extract patient ID from filename
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
filename: e.g., "101228_5.npy" or "c01p01_25.png"
|
| 116 |
+
prefix_length: Number of characters in patient ID
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
Patient ID: e.g., "101228" or "c01p01"
|
| 120 |
+
"""
|
| 121 |
+
return filename.split('_')[0][:prefix_length]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def extract_slice_number(filename: str) -> int:
|
| 125 |
+
"""
|
| 126 |
+
Extract slice number from filename
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
filename: e.g., "101228_5.npy" or "c01p01_25.png"
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Slice number as integer
|
| 133 |
+
"""
|
| 134 |
+
# Get the part before file extension
|
| 135 |
+
basename = filename.split('.')[0]
|
| 136 |
+
# Get the last part after splitting by '_'
|
| 137 |
+
slice_num = basename.split('_')[-1]
|
| 138 |
+
return int(slice_num)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def load_flair_image(flair_path: Path, normalize: bool = False, of_z_score: bool = False) -> np.ndarray:
|
| 142 |
+
"""
|
| 143 |
+
Load FLAIR image (.png format)
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
flair_path: Path to .png file
|
| 147 |
+
normalize: Whether to apply z-score normalization
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
FLAIR image (256, 256, 1) as float32
|
| 151 |
+
"""
|
| 152 |
+
if of_z_score:
|
| 153 |
+
# Load NPY: the already z-scored FLAIR image data
|
| 154 |
+
flair = np.load(str(flair_path).replace('.png','.npy')).astype(np.float32)
|
| 155 |
+
else:
|
| 156 |
+
# Load PNG as grayscale
|
| 157 |
+
flair = cv.imread(str(flair_path), cv.IMREAD_GRAYSCALE).astype(np.float32)
|
| 158 |
+
|
| 159 |
+
# Normalize to [-1, 1]:
|
| 160 |
+
flair = (flair - np.min(flair)) / (np.max(flair) - np.min(flair))
|
| 161 |
+
flair = (2 * flair) - 1
|
| 162 |
+
|
| 163 |
+
# Ensure correct shape
|
| 164 |
+
if len(flair.shape) == 2:
|
| 165 |
+
flair = np.expand_dims(flair, axis=-1)
|
| 166 |
+
|
| 167 |
+
# Additional normalization if needed (should already be normalized)
|
| 168 |
+
if normalize and (np.std(flair) > 2.0 or np.abs(np.mean(flair)) > 1.0):
|
| 169 |
+
# Re-normalize if values seem off
|
| 170 |
+
flair = (flair - np.mean(flair)) / (np.std(flair) + 1e-7)
|
| 171 |
+
|
| 172 |
+
return flair
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def load_mask_image(mask_path: Path) -> np.ndarray:
|
| 176 |
+
"""
|
| 177 |
+
Load mask image (.png format)
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
mask_path: Path to .png file
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
Binary mask (256, 256) as uint8
|
| 184 |
+
"""
|
| 185 |
+
# Load PNG as grayscale
|
| 186 |
+
mask = cv.imread(str(mask_path), cv.IMREAD_GRAYSCALE)
|
| 187 |
+
|
| 188 |
+
if mask is None:
|
| 189 |
+
raise FileNotFoundError(f"Could not load mask: {mask_path}")
|
| 190 |
+
|
| 191 |
+
# Binarize (any non-zero value becomes 1)
|
| 192 |
+
mask = (mask > 0).astype(np.uint8)
|
| 193 |
+
|
| 194 |
+
return mask
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def combine_masks(vent_mask: np.ndarray,
|
| 198 |
+
nwmh_mask: np.ndarray,
|
| 199 |
+
abwmh_mask: np.ndarray,
|
| 200 |
+
class_scenario: str,
|
| 201 |
+
preprocess: bool = False) -> np.ndarray:
|
| 202 |
+
"""
|
| 203 |
+
Combine individual masks into multi-class format
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
vent_mask: Ventricles mask (256, 256)
|
| 207 |
+
nwmh_mask: Normal WMH mask (256, 256)
|
| 208 |
+
abwmh_mask: Abnormal WMH mask (256, 256)
|
| 209 |
+
class_scenario: '3class' or '4class'
|
| 210 |
+
preprocess: Boolean turning the morphological preprocessing on or off
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
Combined mask (256, 256) with class labels
|
| 214 |
+
"""
|
| 215 |
+
if preprocess:
|
| 216 |
+
from skimage.morphology import remove_small_objects, binary_erosion, binary_closing, binary_opening, disk, binary_dilation
|
| 217 |
+
min_object_size = 5
|
| 218 |
+
closing_kernel_size = 2
|
| 219 |
+
dilation_kernel_size = 1
|
| 220 |
+
|
| 221 |
+
vent_mask = vent_mask > 0
|
| 222 |
+
abwmh_mask = abwmh_mask > 0
|
| 223 |
+
nwmh_mask = nwmh_mask > 0
|
| 224 |
+
|
| 225 |
+
abwmh_mask = binary_closing(abwmh_mask, disk(closing_kernel_size))
|
| 226 |
+
abwmh_mask = binary_erosion(abwmh_mask, disk(dilation_kernel_size))
|
| 227 |
+
abwmh_mask = remove_small_objects(abwmh_mask, min_size=min_object_size)
|
| 228 |
+
|
| 229 |
+
nwmh_mask = binary_closing(nwmh_mask, disk(closing_kernel_size))
|
| 230 |
+
nwmh_mask = binary_erosion(nwmh_mask, disk(dilation_kernel_size))
|
| 231 |
+
nwmh_mask = remove_small_objects(nwmh_mask, min_size=min_object_size)
|
| 232 |
+
|
| 233 |
+
vent_mask = binary_closing(vent_mask, disk(closing_kernel_size))
|
| 234 |
+
vent_mask = binary_erosion(vent_mask, disk(dilation_kernel_size))
|
| 235 |
+
vent_mask = remove_small_objects(vent_mask, min_size=min_object_size)
|
| 236 |
+
|
| 237 |
+
abwmh_mask = abwmh_mask & ~vent_mask
|
| 238 |
+
nwmh_mask = nwmh_mask & ~vent_mask
|
| 239 |
+
abwmh_mask = abwmh_mask & ~nwmh_mask
|
| 240 |
+
|
| 241 |
+
if class_scenario == '3class':
|
| 242 |
+
# Class 0: Background (default)
|
| 243 |
+
# Class 1: Ventricles
|
| 244 |
+
# Class 2: Abnormal WMH
|
| 245 |
+
combined = np.zeros_like(vent_mask, dtype=np.uint8)
|
| 246 |
+
combined[vent_mask>0] = 1
|
| 247 |
+
combined[abwmh_mask>0] = 2
|
| 248 |
+
|
| 249 |
+
elif class_scenario == '4class':
|
| 250 |
+
# Class 0: Background (default)
|
| 251 |
+
# Class 1: Ventricles
|
| 252 |
+
# Class 2: Normal WMH
|
| 253 |
+
# Class 3: Abnormal WMH
|
| 254 |
+
combined = np.zeros_like(vent_mask, dtype=np.uint8)
|
| 255 |
+
combined[vent_mask>0] = 1
|
| 256 |
+
combined[nwmh_mask>0] = 2
|
| 257 |
+
combined[abwmh_mask>0] = 3
|
| 258 |
+
|
| 259 |
+
else:
|
| 260 |
+
raise ValueError(f"Unknown class_scenario: {class_scenario}")
|
| 261 |
+
|
| 262 |
+
return combined
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def is_valid_slice(vent_mask: np.ndarray,
|
| 266 |
+
nwmh_mask: np.ndarray,
|
| 267 |
+
abwmh_mask: np.ndarray) -> bool:
|
| 268 |
+
"""
|
| 269 |
+
Check if slice has at least one non-empty mask
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
vent_mask: Ventricles mask (256, 256)
|
| 273 |
+
nwmh_mask: Normal WMH mask (256, 256)
|
| 274 |
+
abwmh_mask: Abnormal WMH mask (256, 256)
|
| 275 |
+
|
| 276 |
+
Returns:
|
| 277 |
+
True if at least one mask has non-zero pixels
|
| 278 |
+
"""
|
| 279 |
+
has_ventricles = np.sum(vent_mask) > 50
|
| 280 |
+
has_nwmh = np.sum(nwmh_mask) > 50
|
| 281 |
+
has_abwmh = np.sum(abwmh_mask) > 50
|
| 282 |
+
|
| 283 |
+
# Valid if ANY mask has content
|
| 284 |
+
return True # or has_nwmh has_ventricles or has_abwmh #
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def create_paired_input(flair: np.ndarray,
|
| 288 |
+
mask: np.ndarray,
|
| 289 |
+
brain_mask: np.ndarray,
|
| 290 |
+
num_classes: np.ndarray,
|
| 291 |
+
if_bet=False) -> np.ndarray:
|
| 292 |
+
"""
|
| 293 |
+
Create paired input: [FLAIR | mask] concatenated horizontally
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
flair: FLAIR image (256, 256, 1) float32
|
| 297 |
+
mask: Combined mask (256, 256) uint8
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
Paired image (256, 512, 1) float32
|
| 301 |
+
"""
|
| 302 |
+
# Binarize (any non-zero value becomes 1)
|
| 303 |
+
brain_mask = brain_mask > 0
|
| 304 |
+
|
| 305 |
+
# Brain extraction
|
| 306 |
+
if if_bet:
|
| 307 |
+
# print("\n\t Doing THEEEEEEEEE BET")
|
| 308 |
+
flair[~brain_mask] = np.min(flair)
|
| 309 |
+
mask[~brain_mask] = 0
|
| 310 |
+
|
| 311 |
+
# Ensure flair is 3D
|
| 312 |
+
if len(flair.shape) == 2:
|
| 313 |
+
flair = np.expand_dims(flair, axis=-1)
|
| 314 |
+
|
| 315 |
+
# Convert mask to float and normalize to [0, 1] range for consistency
|
| 316 |
+
# For 3-class: 0, 1, 2 -> -1, 0, 1.0
|
| 317 |
+
# For 4-class: 0, 1, 2, 3 -> -1, -0.333, 0.333, 1.0
|
| 318 |
+
max_class = num_classes
|
| 319 |
+
mask_normalized = mask.astype(np.float32)
|
| 320 |
+
if max_class > 0:
|
| 321 |
+
mask_normalized = mask_normalized / max_class
|
| 322 |
+
mask_normalized = (2 * mask_normalized) - 1
|
| 323 |
+
|
| 324 |
+
mask_3d = np.expand_dims(mask_normalized, axis=-1)
|
| 325 |
+
|
| 326 |
+
# Concatenate horizontally: [FLAIR | mask]
|
| 327 |
+
paired = np.concatenate([flair, mask_3d], axis=1) # (256, 512, 1)
|
| 328 |
+
|
| 329 |
+
return paired, mask
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
###################### Patient Stratified Splitting ######################
|
| 333 |
+
|
| 334 |
+
class PatientStratifiedSplitter:
|
| 335 |
+
"""
|
| 336 |
+
Create patient-stratified train/val/test splits
|
| 337 |
+
Similar to P6 implementation but adapted for P4 data structure
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
def __init__(self, config: DataConfig):
|
| 341 |
+
self.config = config
|
| 342 |
+
self.config.splits_dir.mkdir(exist_ok=True)
|
| 343 |
+
|
| 344 |
+
def collect_all_patients(self) -> Dict[str, List[str]]:
|
| 345 |
+
"""
|
| 346 |
+
Collect all unique patient IDs from both datasets
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
Dictionary mapping dataset_name -> list of patient IDs
|
| 350 |
+
"""
|
| 351 |
+
all_patients = {}
|
| 352 |
+
|
| 353 |
+
for dataset_name, dataset_config in self.config.datasets.items():
|
| 354 |
+
patients = set()
|
| 355 |
+
|
| 356 |
+
# Path to FLAIR images (standard preprocessing)
|
| 357 |
+
flair_dir = dataset_config['base_path'] / 'FLAIR' / 'Preprocessed' / 'images'
|
| 358 |
+
|
| 359 |
+
if not flair_dir.exists():
|
| 360 |
+
print(f"Warning: {flair_dir} does not exist. Skipping {dataset_name}.")
|
| 361 |
+
continue
|
| 362 |
+
|
| 363 |
+
# Collect all .png files
|
| 364 |
+
for flair_file in flair_dir.glob('*.png'):
|
| 365 |
+
patient_id = extract_patient_id(
|
| 366 |
+
flair_file.name,
|
| 367 |
+
dataset_config['patient_prefix_length']
|
| 368 |
+
)
|
| 369 |
+
patients.add(patient_id)
|
| 370 |
+
|
| 371 |
+
all_patients[dataset_name] = sorted(list(patients))
|
| 372 |
+
print(f"{dataset_name}: {len(all_patients[dataset_name])} patients")
|
| 373 |
+
|
| 374 |
+
return all_patients
|
| 375 |
+
|
| 376 |
+
def create_patient_stratified_splits(self,
|
| 377 |
+
save: bool = True) -> Dict:
|
| 378 |
+
"""
|
| 379 |
+
Create patient-stratified K-fold splits
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
Dictionary containing fold assignments
|
| 383 |
+
"""
|
| 384 |
+
all_patients = self.collect_all_patients()
|
| 385 |
+
|
| 386 |
+
# Combine patients from both datasets
|
| 387 |
+
combined_patients = []
|
| 388 |
+
for dataset_name, patients in all_patients.items():
|
| 389 |
+
combined_patients.extend(patients)
|
| 390 |
+
|
| 391 |
+
combined_patients = np.array(combined_patients)
|
| 392 |
+
total_patients = len(combined_patients)
|
| 393 |
+
|
| 394 |
+
print(f"\nTotal unique patients: {total_patients}")
|
| 395 |
+
|
| 396 |
+
# Step 1: Split into train+val (80%) and test (20%)
|
| 397 |
+
np.random.seed(self.config.random_state)
|
| 398 |
+
test_size = int(total_patients * self.config.test_split)
|
| 399 |
+
|
| 400 |
+
test_indices = np.random.choice(
|
| 401 |
+
total_patients,
|
| 402 |
+
size=test_size,
|
| 403 |
+
replace=False
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
test_patients = combined_patients[test_indices]
|
| 407 |
+
train_val_indices = np.setdiff1d(np.arange(total_patients), test_indices)
|
| 408 |
+
train_val_patients = combined_patients[train_val_indices]
|
| 409 |
+
|
| 410 |
+
print(f"Test patients: {len(test_patients)}")
|
| 411 |
+
print(f"Train+Val patients: {len(train_val_patients)}")
|
| 412 |
+
|
| 413 |
+
# Step 2: Create K-fold splits on train+val patients
|
| 414 |
+
kfold = KFold(
|
| 415 |
+
n_splits=self.config.k_folds,
|
| 416 |
+
shuffle=True,
|
| 417 |
+
random_state=self.config.random_state
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
fold_assignments = {
|
| 421 |
+
'metadata': {
|
| 422 |
+
'total_patients': total_patients,
|
| 423 |
+
'test_patients': len(test_patients),
|
| 424 |
+
'trainval_patients': len(train_val_patients),
|
| 425 |
+
'n_folds': self.config.k_folds,
|
| 426 |
+
'random_seed': self.config.random_state,
|
| 427 |
+
'datasets': list(all_patients.keys())
|
| 428 |
+
},
|
| 429 |
+
'test_set': {
|
| 430 |
+
'patients': test_patients.tolist(),
|
| 431 |
+
'n_patients': len(test_patients)
|
| 432 |
+
},
|
| 433 |
+
'folds': {}
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(train_val_patients)):
|
| 437 |
+
train_patients_fold = train_val_patients[train_idx]
|
| 438 |
+
val_patients_fold = train_val_patients[val_idx]
|
| 439 |
+
|
| 440 |
+
fold_assignments['folds'][f'fold_{fold_idx}'] = {
|
| 441 |
+
'train_patients': train_patients_fold.tolist(),
|
| 442 |
+
'val_patients': val_patients_fold.tolist(),
|
| 443 |
+
'n_train': len(train_patients_fold),
|
| 444 |
+
'n_val': len(val_patients_fold)
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
print(f"Fold {fold_idx}: Train={len(train_patients_fold)}, Val={len(val_patients_fold)}")
|
| 448 |
+
|
| 449 |
+
# Save to JSON
|
| 450 |
+
if save:
|
| 451 |
+
with open(self.config.splits_file, 'w') as f:
|
| 452 |
+
json.dump(fold_assignments, f, indent=2)
|
| 453 |
+
print(f"\n✅ Fold assignments saved to: {self.config.splits_file}")
|
| 454 |
+
|
| 455 |
+
return fold_assignments
|
| 456 |
+
|
| 457 |
+
def load_fold_assignments(self) -> Dict:
|
| 458 |
+
"""Load existing fold assignments from JSON"""
|
| 459 |
+
if not self.config.splits_file.exists():
|
| 460 |
+
raise FileNotFoundError(
|
| 461 |
+
f"Fold assignments not found: {self.config.splits_file}\n"
|
| 462 |
+
f"Run create_patient_stratified_splits() first."
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
with open(self.config.splits_file, 'r') as f:
|
| 466 |
+
fold_assignments = json.load(f)
|
| 467 |
+
|
| 468 |
+
return fold_assignments
|
| 469 |
+
|
| 470 |
+
def verify_patient_separation(self, fold_assignments: Dict) -> bool:
|
| 471 |
+
"""
|
| 472 |
+
Verify no patient appears in multiple folds or in both train/val
|
| 473 |
+
Similar to P6's verification logic
|
| 474 |
+
"""
|
| 475 |
+
print("\n" + "="*60)
|
| 476 |
+
print("VERIFYING PATIENT SEPARATION")
|
| 477 |
+
print("="*60)
|
| 478 |
+
|
| 479 |
+
all_issues = []
|
| 480 |
+
test_patients = set(fold_assignments['test_set']['patients'])
|
| 481 |
+
|
| 482 |
+
# Check 1: No patient in both test and train/val
|
| 483 |
+
for fold_name, fold_data in fold_assignments['folds'].items():
|
| 484 |
+
train_patients = set(fold_data['train_patients'])
|
| 485 |
+
val_patients = set(fold_data['val_patients'])
|
| 486 |
+
|
| 487 |
+
test_train_overlap = test_patients.intersection(train_patients)
|
| 488 |
+
test_val_overlap = test_patients.intersection(val_patients)
|
| 489 |
+
|
| 490 |
+
if test_train_overlap:
|
| 491 |
+
issue = f"{fold_name}: Test-Train overlap: {test_train_overlap}"
|
| 492 |
+
all_issues.append(issue)
|
| 493 |
+
print(f"❌ {issue}")
|
| 494 |
+
|
| 495 |
+
if test_val_overlap:
|
| 496 |
+
issue = f"{fold_name}: Test-Val overlap: {test_val_overlap}"
|
| 497 |
+
all_issues.append(issue)
|
| 498 |
+
print(f"❌ {issue}")
|
| 499 |
+
|
| 500 |
+
# Check 2: No patient in both train and val within same fold
|
| 501 |
+
for fold_name, fold_data in fold_assignments['folds'].items():
|
| 502 |
+
train_patients = set(fold_data['train_patients'])
|
| 503 |
+
val_patients = set(fold_data['val_patients'])
|
| 504 |
+
|
| 505 |
+
train_val_overlap = train_patients.intersection(val_patients)
|
| 506 |
+
if train_val_overlap:
|
| 507 |
+
issue = f"{fold_name}: Train-Val overlap: {train_val_overlap}"
|
| 508 |
+
all_issues.append(issue)
|
| 509 |
+
print(f"❌ {issue}")
|
| 510 |
+
|
| 511 |
+
# Check 3: Each patient in validation exactly once
|
| 512 |
+
all_val_patients = []
|
| 513 |
+
for fold_data in fold_assignments['folds'].values():
|
| 514 |
+
all_val_patients.extend(fold_data['val_patients'])
|
| 515 |
+
|
| 516 |
+
val_patient_counts = {}
|
| 517 |
+
for patient in all_val_patients:
|
| 518 |
+
val_patient_counts[patient] = val_patient_counts.get(patient, 0) + 1
|
| 519 |
+
|
| 520 |
+
for patient, count in val_patient_counts.items():
|
| 521 |
+
if count != 1:
|
| 522 |
+
issue = f"Patient {patient} in validation {count} times (should be 1)"
|
| 523 |
+
all_issues.append(issue)
|
| 524 |
+
print(f"❌ {issue}")
|
| 525 |
+
|
| 526 |
+
if not all_issues:
|
| 527 |
+
print("✅ All patient separation checks passed")
|
| 528 |
+
print("✅ No data leakage detected")
|
| 529 |
+
return True
|
| 530 |
+
else:
|
| 531 |
+
print(f"\n❌ Found {len(all_issues)} issues")
|
| 532 |
+
return False
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
###################### Data Loader ######################
|
| 536 |
+
|
| 537 |
+
class P2DataLoader:
|
| 538 |
+
"""
|
| 539 |
+
Main data loader for P2 experiments
|
| 540 |
+
Handles loading FLAIR and masks, creating paired inputs, TensorFlow datasets
|
| 541 |
+
"""
|
| 542 |
+
|
| 543 |
+
def __init__(self, config: DataConfig):
|
| 544 |
+
self.config = config
|
| 545 |
+
|
| 546 |
+
def get_file_paths(self,
|
| 547 |
+
patient_id: str,
|
| 548 |
+
slice_num: int,
|
| 549 |
+
dataset_name: str,
|
| 550 |
+
preprocessing: str) -> Dict[str, Path]:
|
| 551 |
+
"""
|
| 552 |
+
Construct file paths for a given patient-slice
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
patient_id: e.g., "101228" or "c01p01"
|
| 556 |
+
slice_num: Slice number
|
| 557 |
+
dataset_name: 'Local_SAI_updated' or 'Public_MSSEG'
|
| 558 |
+
preprocessing: 'standard' or 'zoomed'
|
| 559 |
+
|
| 560 |
+
Returns:
|
| 561 |
+
Dictionary with paths to FLAIR and mask files
|
| 562 |
+
"""
|
| 563 |
+
dataset_config = self.config.datasets[dataset_name]
|
| 564 |
+
base_path = dataset_config['base_path']
|
| 565 |
+
|
| 566 |
+
# Determine subdirectory based on preprocessing
|
| 567 |
+
if preprocessing == 'standard':
|
| 568 |
+
flair_subdir = 'images'
|
| 569 |
+
gt_subdir = 'images'
|
| 570 |
+
else: # zoomed
|
| 571 |
+
flair_subdir = 'zoomed/images'
|
| 572 |
+
gt_subdir = 'zoomed/images'
|
| 573 |
+
|
| 574 |
+
# Construct paths
|
| 575 |
+
flair_path = base_path / 'FLAIR' / 'Preprocessed' / flair_subdir / f'{patient_id}_{slice_num}.png'
|
| 576 |
+
vent_path = base_path / 'GroundTruth' / gt_subdir / 'Vent_Masks' / f'{patient_id}_{slice_num}.png'
|
| 577 |
+
nwmh_path = base_path / 'GroundTruth' / gt_subdir / 'nWMH_Masks' / f'{patient_id}_{slice_num}.png'
|
| 578 |
+
abwmh_path = base_path / 'GroundTruth' / gt_subdir / 'abWMH_Masks' / f'{patient_id}_{slice_num}.png'
|
| 579 |
+
brain_path = base_path / 'GroundTruth' / gt_subdir / 'Brain_Masks' / f'{patient_id}_{slice_num}.png'
|
| 580 |
+
|
| 581 |
+
# Optional: zooming factors (only for zoomed preprocessing)
|
| 582 |
+
zoom_factors_path = None
|
| 583 |
+
if preprocessing == 'zoomed':
|
| 584 |
+
zoom_factors_path = base_path / 'FLAIR' / 'Preprocessed' / 'zoomed' / 'images' / f'{patient_id}_zooming_factors.npy'
|
| 585 |
+
|
| 586 |
+
return {
|
| 587 |
+
'flair': flair_path,
|
| 588 |
+
'vent_mask': vent_path,
|
| 589 |
+
'nwmh_mask': nwmh_path,
|
| 590 |
+
'abwmh_mask': abwmh_path,
|
| 591 |
+
'brain_mask': brain_path,
|
| 592 |
+
'zoom_factors': zoom_factors_path
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
def load_single_slice(self,
|
| 596 |
+
patient_id: str,
|
| 597 |
+
slice_num: int,
|
| 598 |
+
dataset_name: str,
|
| 599 |
+
preprocessing: str,
|
| 600 |
+
class_scenario: str,
|
| 601 |
+
of_z_score: bool = True,
|
| 602 |
+
if_bet: bool = True,
|
| 603 |
+
pre_morph: bool = False) -> Tuple[np.ndarray, np.ndarray]:
|
| 604 |
+
"""
|
| 605 |
+
Load a single patient-slice and create paired input
|
| 606 |
+
|
| 607 |
+
Args:
|
| 608 |
+
patient_id: Patient identifier
|
| 609 |
+
slice_num: Slice number
|
| 610 |
+
dataset_name: 'Local_SAI_updated' or 'Public_MSSEG'
|
| 611 |
+
preprocessing: 'standard' or 'zoomed'
|
| 612 |
+
class_scenario: '3class' or '4class'
|
| 613 |
+
|
| 614 |
+
Returns:
|
| 615 |
+
Tuple of (paired_input, combined_mask)
|
| 616 |
+
- paired_input: (256, 512, 1) FLAIR + mask concatenated
|
| 617 |
+
- combined_mask: (256, 256) multi-class labels
|
| 618 |
+
"""
|
| 619 |
+
# Class number
|
| 620 |
+
num_classes = int(class_scenario[0]) - 1
|
| 621 |
+
|
| 622 |
+
# Get file paths
|
| 623 |
+
paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)
|
| 624 |
+
|
| 625 |
+
# Load FLAIR
|
| 626 |
+
flair = load_flair_image(paths['flair'], of_z_score=of_z_score)
|
| 627 |
+
|
| 628 |
+
# Load masks
|
| 629 |
+
vent_mask = load_mask_image(paths['vent_mask'])
|
| 630 |
+
nwmh_mask = load_mask_image(paths['nwmh_mask'])
|
| 631 |
+
abwmh_mask = load_mask_image(paths['abwmh_mask'])
|
| 632 |
+
brain_mask = load_mask_image(paths['brain_mask'])
|
| 633 |
+
|
| 634 |
+
# Combine masks
|
| 635 |
+
combined_mask = combine_masks(vent_mask, nwmh_mask, abwmh_mask, class_scenario, preprocess=pre_morph)
|
| 636 |
+
|
| 637 |
+
# Create paired input
|
| 638 |
+
paired_input, combined_mask = create_paired_input(flair, combined_mask, brain_mask, num_classes=num_classes, if_bet=if_bet)
|
| 639 |
+
|
| 640 |
+
return paired_input, combined_mask
|
| 641 |
+
|
| 642 |
+
def collect_patient_slices(self,
|
| 643 |
+
patient_list: List[str],
|
| 644 |
+
dataset_name: str,
|
| 645 |
+
preprocessing: str) -> List[Tuple[str, int, str]]:
|
| 646 |
+
"""
|
| 647 |
+
Collect all valid slice files for given patients
|
| 648 |
+
FILTERS OUT SLICES WITH ALL EMPTY MASKS
|
| 649 |
+
|
| 650 |
+
Args:
|
| 651 |
+
patient_list: List of patient IDs
|
| 652 |
+
dataset_name: 'Local_SAI_updated' or 'Public_MSSEG'
|
| 653 |
+
preprocessing: 'standard' or 'zoomed'
|
| 654 |
+
|
| 655 |
+
Returns:
|
| 656 |
+
List of tuples (patient_id, slice_num, dataset_name)
|
| 657 |
+
"""
|
| 658 |
+
dataset_config = self.config.datasets[dataset_name]
|
| 659 |
+
slice_min, slice_max = dataset_config['slice_range']
|
| 660 |
+
|
| 661 |
+
patient_slices = []
|
| 662 |
+
skipped_empty = 0
|
| 663 |
+
|
| 664 |
+
for patient_id in patient_list:
|
| 665 |
+
# Check which dataset this patient belongs to
|
| 666 |
+
# Try to find patient in current dataset
|
| 667 |
+
for slice_num in range(slice_min, slice_max + 1):
|
| 668 |
+
paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)
|
| 669 |
+
|
| 670 |
+
# Check if all required files exist
|
| 671 |
+
if (paths['flair'].exists() and
|
| 672 |
+
paths['vent_mask'].exists() and
|
| 673 |
+
paths['nwmh_mask'].exists() and
|
| 674 |
+
paths['abwmh_mask'].exists() and
|
| 675 |
+
paths['brain_mask'].exists()):
|
| 676 |
+
|
| 677 |
+
# VALIDATION: Check if masks are not all empty
|
| 678 |
+
try:
|
| 679 |
+
vent_mask = load_mask_image(paths['vent_mask'])
|
| 680 |
+
nwmh_mask = load_mask_image(paths['nwmh_mask'])
|
| 681 |
+
abwmh_mask = load_mask_image(paths['abwmh_mask'])
|
| 682 |
+
brain_mask = load_mask_image(paths['brain_mask'])
|
| 683 |
+
|
| 684 |
+
# Only add if at least one mask has content
|
| 685 |
+
if is_valid_slice(vent_mask, nwmh_mask, abwmh_mask):
|
| 686 |
+
patient_slices.append((patient_id, slice_num, dataset_name))
|
| 687 |
+
else:
|
| 688 |
+
skipped_empty += 1
|
| 689 |
+
|
| 690 |
+
except Exception as e:
|
| 691 |
+
print(f"Warning: Could not validate {patient_id}_{slice_num}: {e}")
|
| 692 |
+
skipped_empty += 1
|
| 693 |
+
|
| 694 |
+
if skipped_empty > 0:
|
| 695 |
+
print(f" ⚠️ Skipped {skipped_empty} slices with empty masks")
|
| 696 |
+
|
| 697 |
+
return patient_slices
|
| 698 |
+
|
| 699 |
+
def create_dataset_for_fold(self,
|
| 700 |
+
fold_id: int,
|
| 701 |
+
split: str,
|
| 702 |
+
preprocessing: str,
|
| 703 |
+
class_scenario: str,
|
| 704 |
+
batch_size: int = 1,
|
| 705 |
+
shuffle: bool = True,
|
| 706 |
+
use_z_scored: bool = True,
|
| 707 |
+
bet: bool = False) -> tf.data.Dataset:
|
| 708 |
+
"""
|
| 709 |
+
Create TensorFlow dataset for a specific fold and split
|
| 710 |
+
|
| 711 |
+
Args:
|
| 712 |
+
fold_id: Fold number (0-4)
|
| 713 |
+
split: 'train', 'val', or 'test'
|
| 714 |
+
preprocessing: 'standard' or 'zoomed'
|
| 715 |
+
class_scenario: '3class' or '4class'
|
| 716 |
+
batch_size: Batch size
|
| 717 |
+
shuffle: Whether to shuffle data
|
| 718 |
+
|
| 719 |
+
Returns:
|
| 720 |
+
tf.data.Dataset yielding (paired_input, combined_mask) batches
|
| 721 |
+
"""
|
| 722 |
+
# Load fold assignments
|
| 723 |
+
splitter = PatientStratifiedSplitter(self.config)
|
| 724 |
+
fold_assignments = splitter.load_fold_assignments()
|
| 725 |
+
|
| 726 |
+
# Get patient list for this split
|
| 727 |
+
if split == 'test':
|
| 728 |
+
patient_list = fold_assignments['test_set']['patients']
|
| 729 |
+
else:
|
| 730 |
+
fold_key = f'fold_{fold_id}'
|
| 731 |
+
if split == 'train':
|
| 732 |
+
patient_list = fold_assignments['folds'][fold_key]['train_patients']
|
| 733 |
+
elif split == 'val':
|
| 734 |
+
patient_list = fold_assignments['folds'][fold_key]['val_patients']
|
| 735 |
+
else:
|
| 736 |
+
raise ValueError(f"Unknown split: {split}")
|
| 737 |
+
|
| 738 |
+
print(f"\nCreating dataset for fold {fold_id}, split '{split}'")
|
| 739 |
+
print(f"Patients: {len(patient_list)}")
|
| 740 |
+
|
| 741 |
+
# Collect all patient-slices from both datasets
|
| 742 |
+
all_patient_slices = []
|
| 743 |
+
|
| 744 |
+
for dataset_name in self.config.datasets.keys():
|
| 745 |
+
# Filter patient list to only include patients from this dataset
|
| 746 |
+
# This is done by checking patient ID prefix
|
| 747 |
+
dataset_patients = [p for p in patient_list]
|
| 748 |
+
|
| 749 |
+
patient_slices = self.collect_patient_slices(
|
| 750 |
+
dataset_patients,
|
| 751 |
+
dataset_name,
|
| 752 |
+
preprocessing
|
| 753 |
+
)
|
| 754 |
+
all_patient_slices.extend(patient_slices)
|
| 755 |
+
|
| 756 |
+
print(f"Total slices: {len(all_patient_slices)}")
|
| 757 |
+
|
| 758 |
+
if len(all_patient_slices) == 0:
|
| 759 |
+
raise ValueError(f"No data found for fold {fold_id}, split '{split}'")
|
| 760 |
+
|
| 761 |
+
# Create TensorFlow dataset
|
| 762 |
+
def data_generator():
|
| 763 |
+
"""Generator function for tf.data.Dataset"""
|
| 764 |
+
for patient_id, slice_num, dataset_name in all_patient_slices:
|
| 765 |
+
try:
|
| 766 |
+
paired_input, combined_mask = self.load_single_slice(
|
| 767 |
+
patient_id, slice_num, dataset_name,
|
| 768 |
+
preprocessing, class_scenario
|
| 769 |
+
)
|
| 770 |
+
yield paired_input, combined_mask, patient_id, slice_num
|
| 771 |
+
except Exception as e:
|
| 772 |
+
print(f"Error loading {patient_id}_{slice_num}: {e}")
|
| 773 |
+
continue
|
| 774 |
+
|
| 775 |
+
# Create dataset
|
| 776 |
+
dataset = tf.data.Dataset.from_generator(
|
| 777 |
+
data_generator,
|
| 778 |
+
output_signature=(
|
| 779 |
+
tf.TensorSpec(shape=(256, 512, 1), dtype=tf.float32), # concatenated image
|
| 780 |
+
tf.TensorSpec(shape=(256, 256), dtype=tf.uint8), # multi-level mask
|
| 781 |
+
tf.TensorSpec(shape=(), dtype=tf.string), # patient_id
|
| 782 |
+
tf.TensorSpec(shape=(), dtype=tf.int32) # slice_num
|
| 783 |
+
)
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# ── Cache BEFORE shuffle/batch ──────────────────────────────────────
|
| 787 |
+
# On epoch 1 the generator runs once and all 700 samples are stored
|
| 788 |
+
# in RAM (~350 MB). From epoch 2 onward no disk I/O occurs at all.
|
| 789 |
+
# Placing cache HERE (on unbatched, unshuffled samples) means:
|
| 790 |
+
# • The expensive load/decode/combine step is paid only once.
|
| 791 |
+
# • Shuffle re-randomises the order freshly each epoch (because
|
| 792 |
+
# reshuffle_each_iteration=True is the default).
|
| 793 |
+
# • Batch composition therefore differs every epoch as desired.
|
| 794 |
+
dataset = dataset.cache()
|
| 795 |
+
|
| 796 |
+
# Shuffle if training (acts on the in-RAM cache every epoch)
|
| 797 |
+
if shuffle and split == 'train':
|
| 798 |
+
dataset = dataset.shuffle(
|
| 799 |
+
buffer_size=len(all_patient_slices),
|
| 800 |
+
reshuffle_each_iteration=True # new random order each epoch
|
| 801 |
+
)
|
| 802 |
+
|
| 803 |
+
# Batch and prefetch
|
| 804 |
+
dataset = dataset.batch(batch_size)
|
| 805 |
+
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
| 806 |
+
|
| 807 |
+
return dataset
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
###################### Testing & Validation Functions ######################
|
| 811 |
+
|
| 812 |
+
def test_data_loading():
|
| 813 |
+
"""Test data loading functionality"""
|
| 814 |
+
print("\n" + "="*60)
|
| 815 |
+
print("TESTING DATA LOADING")
|
| 816 |
+
print("="*60)
|
| 817 |
+
|
| 818 |
+
config = DataConfig()
|
| 819 |
+
|
| 820 |
+
# Test 1: Create fold assignments
|
| 821 |
+
print("\n[TEST 1] Creating patient stratified splits...")
|
| 822 |
+
splitter = PatientStratifiedSplitter(config)
|
| 823 |
+
fold_assignments = splitter.create_patient_stratified_splits(save=True)
|
| 824 |
+
|
| 825 |
+
# Verify patient separation
|
| 826 |
+
is_valid = splitter.verify_patient_separation(fold_assignments)
|
| 827 |
+
|
| 828 |
+
if not is_valid:
|
| 829 |
+
print("❌ Patient separation verification failed!")
|
| 830 |
+
return False
|
| 831 |
+
|
| 832 |
+
# Test 2: Load a single slice
|
| 833 |
+
print("\n[TEST 2] Loading single slice...")
|
| 834 |
+
loader = P2DataLoader(config)
|
| 835 |
+
|
| 836 |
+
# Get a test patient from fold 0 train set
|
| 837 |
+
test_patient = fold_assignments['folds']['fold_0']['train_patients'][0]
|
| 838 |
+
|
| 839 |
+
# Determine which dataset this patient belongs to
|
| 840 |
+
if test_patient.startswith('c'):
|
| 841 |
+
test_dataset = 'Public_MSSEG'
|
| 842 |
+
test_slice = 25 # Middle of 20-46 range
|
| 843 |
+
else:
|
| 844 |
+
test_dataset = 'Local_SAI_updated'
|
| 845 |
+
test_slice = 10 # Middle of 8-15 range
|
| 846 |
+
|
| 847 |
+
try:
|
| 848 |
+
paired_input, combined_mask = loader.load_single_slice(
|
| 849 |
+
test_patient, test_slice, test_dataset,
|
| 850 |
+
'standard', '4class'
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
print(f"✅ Loaded slice {test_patient}_{test_slice}")
|
| 854 |
+
print(f" Paired input shape: {paired_input.shape}")
|
| 855 |
+
print(f" Combined mask shape: {combined_mask.shape}")
|
| 856 |
+
print(f" Mask unique values: {np.unique(combined_mask)}")
|
| 857 |
+
|
| 858 |
+
except Exception as e:
|
| 859 |
+
print(f"❌ Failed to load slice: {e}")
|
| 860 |
+
return False
|
| 861 |
+
|
| 862 |
+
# Test 3: Create TensorFlow dataset
|
| 863 |
+
print("\n[TEST 3] Creating TensorFlow dataset...")
|
| 864 |
+
try:
|
| 865 |
+
dataset = loader.create_dataset_for_fold(
|
| 866 |
+
fold_id=0,
|
| 867 |
+
split='train',
|
| 868 |
+
preprocessing='standard',
|
| 869 |
+
class_scenario='4class',
|
| 870 |
+
batch_size=2,
|
| 871 |
+
shuffle=True
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
# Get first batch
|
| 875 |
+
for batch_paired, batch_masks in dataset.take(1):
|
| 876 |
+
print(f"✅ Created dataset")
|
| 877 |
+
print(f" Batch paired input shape: {batch_paired.shape}")
|
| 878 |
+
print(f" Batch masks shape: {batch_masks.shape}")
|
| 879 |
+
print(f" Paired input dtype: {batch_paired.dtype}")
|
| 880 |
+
print(f" Masks dtype: {batch_masks.dtype}")
|
| 881 |
+
|
| 882 |
+
except Exception as e:
|
| 883 |
+
print(f"❌ Failed to create dataset: {e}")
|
| 884 |
+
return False
|
| 885 |
+
|
| 886 |
+
print("\n" + "="*60)
|
| 887 |
+
print("✅ ALL TESTS PASSED")
|
| 888 |
+
print("="*60)
|
| 889 |
+
|
| 890 |
+
return True
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
###################### Main Execution ######################
|
| 894 |
+
|
| 895 |
+
if __name__ == "__main__":
|
| 896 |
+
# Run tests
|
| 897 |
+
success = test_data_loading()
|
| 898 |
+
|
| 899 |
+
if success:
|
| 900 |
+
print("\n" + "="*60)
|
| 901 |
+
print("DATA LOADER READY FOR USE")
|
| 902 |
+
print("="*60)
|
| 903 |
+
print("\nNext steps:")
|
| 904 |
+
print("1. Verify fold_assignments.json created in data_splits/")
|
| 905 |
+
print("2. Check that all file paths are correct for your system")
|
| 906 |
+
print("3. Proceed to model implementation")
|
| 907 |
+
else:
|
| 908 |
+
print("\n" + "="*60)
|
| 909 |
+
print("❌ DATA LOADER TESTS FAILED")
|
| 910 |
+
print("="*60)
|
| 911 |
+
print("\nPlease fix the issues above before proceeding")
|
| 912 |
+
|
models/for_WMH_Vent/model_training_scripts/p4_error_analysis.py
ADDED
|
@@ -0,0 +1,1033 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P2 Article - Error Analysis & Hard Case Ranking Module
|
| 3 |
+
for Ventricles and WMH Segmentation
|
| 4 |
+
|
| 5 |
+
Integrates with p4_inference.py to identify problematic slices and patients,
|
| 6 |
+
rank them by difficulty, and produce rich diagnostic visualizations.
|
| 7 |
+
|
| 8 |
+
Developer: Mahdi Bashiri Bawil
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import matplotlib.gridspec as gridspec
|
| 14 |
+
import matplotlib.patches as mpatches
|
| 15 |
+
from matplotlib.colors import ListedColormap, BoundaryNorm
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import json
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from scipy.ndimage import binary_erosion, label as scipy_label
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 25 |
+
# SECTION 1 — Slice-level metric computation
|
| 26 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
|
| 28 |
+
def _dice_binary(gt_bin, pred_bin):
|
| 29 |
+
"""Dice for a single binary mask pair. Returns NaN if both are empty."""
|
| 30 |
+
tp = np.sum(gt_bin & pred_bin)
|
| 31 |
+
denom = np.sum(gt_bin) + np.sum(pred_bin)
|
| 32 |
+
if denom == 0:
|
| 33 |
+
return np.nan # class truly absent — not a failure
|
| 34 |
+
return float(2 * tp / (denom + 1e-7))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _iou_binary(gt_bin, pred_bin):
|
| 38 |
+
tp = np.sum(gt_bin & pred_bin)
|
| 39 |
+
denom = np.sum(gt_bin | pred_bin)
|
| 40 |
+
if denom == 0:
|
| 41 |
+
return np.nan
|
| 42 |
+
return float(tp / (denom + 1e-7))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _precision_recall(gt_bin, pred_bin):
|
| 46 |
+
tp = np.sum(gt_bin & pred_bin)
|
| 47 |
+
fp = np.sum(~gt_bin & pred_bin)
|
| 48 |
+
fn = np.sum(gt_bin & ~pred_bin)
|
| 49 |
+
precision = float(tp / (tp + fp + 1e-7))
|
| 50 |
+
recall = float(tp / (tp + fn + 1e-7))
|
| 51 |
+
return precision, recall
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _false_positive_volume(gt_bin, pred_bin):
|
| 55 |
+
"""Fraction of predicted pixels that are false positives."""
|
| 56 |
+
fp = np.sum(~gt_bin & pred_bin)
|
| 57 |
+
total_pred = np.sum(pred_bin)
|
| 58 |
+
if total_pred == 0:
|
| 59 |
+
return 0.0
|
| 60 |
+
return float(fp / total_pred)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _false_negative_volume(gt_bin, pred_bin):
|
| 64 |
+
"""Fraction of GT pixels that are missed."""
|
| 65 |
+
fn = np.sum(gt_bin & ~pred_bin)
|
| 66 |
+
total_gt = np.sum(gt_bin)
|
| 67 |
+
if total_gt == 0:
|
| 68 |
+
return 0.0
|
| 69 |
+
return float(fn / total_gt)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _gt_load(gt_hw, class_idx):
|
| 73 |
+
"""Return binary GT mask for a specific class from a (H,W) label map."""
|
| 74 |
+
return gt_hw == class_idx
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _pred_load(pred_hw, class_idx):
|
| 78 |
+
return pred_hw == class_idx
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def compute_slice_metrics(gt_hw, pred_hw, num_classes, class_names,
|
| 82 |
+
mean_confidence=None):
|
| 83 |
+
"""
|
| 84 |
+
Compute per-class and summary metrics for a single 2-D slice.
|
| 85 |
+
|
| 86 |
+
Parameters
|
| 87 |
+
----------
|
| 88 |
+
gt_hw : np.ndarray (H, W) — integer label map (ground truth)
|
| 89 |
+
pred_hw : np.ndarray (H, W) — integer label map (prediction)
|
| 90 |
+
num_classes : int
|
| 91 |
+
class_names : list[str]
|
| 92 |
+
mean_confidence : float | None — mean max-softmax probability for the slice
|
| 93 |
+
|
| 94 |
+
Returns
|
| 95 |
+
-------
|
| 96 |
+
dict with per-class and aggregate metrics
|
| 97 |
+
"""
|
| 98 |
+
results = {}
|
| 99 |
+
dice_values = []
|
| 100 |
+
iou_values = []
|
| 101 |
+
|
| 102 |
+
for cls in range(num_classes):
|
| 103 |
+
gt_bin = _gt_load(gt_hw, cls)
|
| 104 |
+
pred_bin = _pred_load(pred_hw, cls)
|
| 105 |
+
|
| 106 |
+
dice = _dice_binary(gt_bin, pred_bin)
|
| 107 |
+
iou = _iou_binary(gt_bin, pred_bin)
|
| 108 |
+
prec, rec = _precision_recall(gt_bin, pred_bin)
|
| 109 |
+
fpr = _false_positive_volume(gt_bin, pred_bin)
|
| 110 |
+
fnr = _false_negative_volume(gt_bin, pred_bin)
|
| 111 |
+
|
| 112 |
+
gt_px = int(np.sum(gt_bin))
|
| 113 |
+
pred_px = int(np.sum(pred_bin))
|
| 114 |
+
error_px = int(np.sum(gt_bin != pred_bin))
|
| 115 |
+
|
| 116 |
+
results[class_names[cls]] = {
|
| 117 |
+
'dice': dice,
|
| 118 |
+
'iou': iou,
|
| 119 |
+
'precision': prec,
|
| 120 |
+
'recall': rec,
|
| 121 |
+
'fp_rate': fpr,
|
| 122 |
+
'fn_rate': fnr,
|
| 123 |
+
'gt_pixels': gt_px,
|
| 124 |
+
'pred_pixels': pred_px,
|
| 125 |
+
'error_pixels': error_px,
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
if not np.isnan(dice):
|
| 129 |
+
dice_values.append(dice)
|
| 130 |
+
if not np.isnan(iou):
|
| 131 |
+
iou_values.append(iou)
|
| 132 |
+
|
| 133 |
+
# Pixel-level error rate (ignoring class)
|
| 134 |
+
total_px = gt_hw.size
|
| 135 |
+
wrong_px = int(np.sum(gt_hw != pred_hw))
|
| 136 |
+
error_rate = wrong_px / total_px
|
| 137 |
+
|
| 138 |
+
# Focus on foreground classes only (skip background=0) for composite score
|
| 139 |
+
fg_dice = []
|
| 140 |
+
for cls in range(1, num_classes):
|
| 141 |
+
d = results[class_names[cls]]['dice']
|
| 142 |
+
if not np.isnan(d):
|
| 143 |
+
fg_dice.append(d)
|
| 144 |
+
|
| 145 |
+
mean_fg_dice = float(np.mean(fg_dice)) if fg_dice else np.nan
|
| 146 |
+
min_fg_dice = float(np.min(fg_dice)) if fg_dice else np.nan
|
| 147 |
+
|
| 148 |
+
results['_summary'] = {
|
| 149 |
+
'error_rate': error_rate,
|
| 150 |
+
'wrong_pixels': wrong_px,
|
| 151 |
+
'total_pixels': total_px,
|
| 152 |
+
'mean_fg_dice': mean_fg_dice,
|
| 153 |
+
'min_fg_dice': min_fg_dice,
|
| 154 |
+
'mean_confidence': mean_confidence,
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
return results
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 161 |
+
# SECTION 2 — Build slice-level and patient-level tables
|
| 162 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 163 |
+
|
| 164 |
+
def build_error_tables(patient_results, num_classes, class_names):
|
| 165 |
+
"""
|
| 166 |
+
Iterate over all patients / slices stored in patient_results
|
| 167 |
+
(the dict returned by run_inference) and build:
|
| 168 |
+
|
| 169 |
+
- slice_records : list of dicts, one per 2-D slice
|
| 170 |
+
- patient_records : list of dicts, one per patient (aggregated)
|
| 171 |
+
|
| 172 |
+
Parameters
|
| 173 |
+
----------
|
| 174 |
+
patient_results : dict
|
| 175 |
+
{patient_id: {'predictions', 'ground_truths', 'probabilities',
|
| 176 |
+
'flairs', 'slice_indices'}}
|
| 177 |
+
num_classes : int
|
| 178 |
+
class_names : list[str]
|
| 179 |
+
|
| 180 |
+
Returns
|
| 181 |
+
-------
|
| 182 |
+
slice_df : pd.DataFrame
|
| 183 |
+
patient_df : pd.DataFrame
|
| 184 |
+
"""
|
| 185 |
+
slice_records = []
|
| 186 |
+
patient_records = []
|
| 187 |
+
|
| 188 |
+
for patient_id, data in tqdm(patient_results.items(),
|
| 189 |
+
desc="Building error tables"):
|
| 190 |
+
order = np.argsort(data['slice_indices'])
|
| 191 |
+
|
| 192 |
+
preds = np.array(data['predictions'])[order] # (S, H, W)
|
| 193 |
+
gts = np.array(data['ground_truths'])[order] # (S, H, W, C) or (S, H, W)
|
| 194 |
+
probs = np.array(data['probabilities'])[order] # (S, H, W)
|
| 195 |
+
slices = np.array(data['slice_indices'])[order] # (S,)
|
| 196 |
+
|
| 197 |
+
# Ground truth may be one-hot: collapse to label map
|
| 198 |
+
if gts.ndim == 4:
|
| 199 |
+
gts = np.argmax(gts, axis=-1)
|
| 200 |
+
|
| 201 |
+
patient_fg_dice = defaultdict(list)
|
| 202 |
+
patient_error_rates = []
|
| 203 |
+
|
| 204 |
+
for i, slice_num in enumerate(slices):
|
| 205 |
+
gt_hw = gts[i]
|
| 206 |
+
pred_hw = preds[i]
|
| 207 |
+
prob_hw = probs[i]
|
| 208 |
+
|
| 209 |
+
mean_conf = float(np.mean(prob_hw))
|
| 210 |
+
m = compute_slice_metrics(gt_hw, pred_hw, num_classes,
|
| 211 |
+
class_names, mean_confidence=mean_conf)
|
| 212 |
+
|
| 213 |
+
row = {
|
| 214 |
+
'patient_id': patient_id,
|
| 215 |
+
'slice_num': int(slice_num),
|
| 216 |
+
'slice_id': f"{patient_id}_slice_{int(slice_num):03d}",
|
| 217 |
+
'error_rate': m['_summary']['error_rate'],
|
| 218 |
+
'wrong_pixels': m['_summary']['wrong_pixels'],
|
| 219 |
+
'mean_fg_dice': m['_summary']['mean_fg_dice'],
|
| 220 |
+
'min_fg_dice': m['_summary']['min_fg_dice'],
|
| 221 |
+
'mean_confidence': m['_summary']['mean_confidence'],
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
for cls in range(num_classes):
|
| 225 |
+
cname = class_names[cls]
|
| 226 |
+
cm = m[cname]
|
| 227 |
+
prefix = cname.lower().replace(' ', '_')
|
| 228 |
+
row[f'{prefix}_dice'] = cm['dice']
|
| 229 |
+
row[f'{prefix}_iou'] = cm['iou']
|
| 230 |
+
row[f'{prefix}_precision'] = cm['precision']
|
| 231 |
+
row[f'{prefix}_recall'] = cm['recall']
|
| 232 |
+
row[f'{prefix}_fp_rate'] = cm['fp_rate']
|
| 233 |
+
row[f'{prefix}_fn_rate'] = cm['fn_rate']
|
| 234 |
+
row[f'{prefix}_gt_px'] = cm['gt_pixels']
|
| 235 |
+
row[f'{prefix}_pred_px'] = cm['pred_pixels']
|
| 236 |
+
row[f'{prefix}_err_px'] = cm['error_pixels']
|
| 237 |
+
|
| 238 |
+
if cls > 0 and not np.isnan(cm['dice']):
|
| 239 |
+
patient_fg_dice[cname].append(cm['dice'])
|
| 240 |
+
|
| 241 |
+
patient_error_rates.append(m['_summary']['error_rate'])
|
| 242 |
+
slice_records.append(row)
|
| 243 |
+
|
| 244 |
+
# ── Patient summary ──
|
| 245 |
+
pat_row = {'patient_id': patient_id,
|
| 246 |
+
'n_slices': len(slices),
|
| 247 |
+
'mean_error_rate': float(np.mean(patient_error_rates))}
|
| 248 |
+
for cls in range(1, num_classes):
|
| 249 |
+
cname = class_names[cls]
|
| 250 |
+
vals = patient_fg_dice[cname]
|
| 251 |
+
prefix = cname.lower().replace(' ', '_')
|
| 252 |
+
pat_row[f'{prefix}_mean_dice'] = float(np.mean(vals)) if vals else np.nan
|
| 253 |
+
pat_row[f'{prefix}_std_dice'] = float(np.std(vals)) if vals else np.nan
|
| 254 |
+
pat_row[f'{prefix}_min_dice'] = float(np.min(vals)) if vals else np.nan
|
| 255 |
+
|
| 256 |
+
# Composite: mean of per-class mean dices (foreground only)
|
| 257 |
+
fg_means = [pat_row[f"{class_names[c].lower().replace(' ', '_')}_mean_dice"]
|
| 258 |
+
for c in range(1, num_classes)
|
| 259 |
+
if not np.isnan(pat_row.get(
|
| 260 |
+
f"{class_names[c].lower().replace(' ','_')}_mean_dice", np.nan))]
|
| 261 |
+
pat_row['composite_dice'] = float(np.mean(fg_means)) if fg_means else np.nan
|
| 262 |
+
|
| 263 |
+
patient_records.append(pat_row)
|
| 264 |
+
|
| 265 |
+
slice_df = pd.DataFrame(slice_records)
|
| 266 |
+
patient_df = pd.DataFrame(patient_records)
|
| 267 |
+
|
| 268 |
+
return slice_df, patient_df
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 272 |
+
# SECTION 3 — Composite difficulty score & ranking
|
| 273 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 274 |
+
|
| 275 |
+
def rank_slices(slice_df, class_names, num_classes,
|
| 276 |
+
fg_dice_weight=0.6, error_rate_weight=0.2,
|
| 277 |
+
confidence_weight=0.2):
|
| 278 |
+
"""
|
| 279 |
+
Add a `difficulty_score` column to slice_df (higher = harder).
|
| 280 |
+
|
| 281 |
+
Score = fg_dice_weight * (1 - mean_fg_dice)
|
| 282 |
+
+ error_rate_weight * error_rate
|
| 283 |
+
+ confidence_weight * (1 - mean_confidence)
|
| 284 |
+
|
| 285 |
+
NaN dice (class absent in GT) is neutral (0.5) so it doesn't
|
| 286 |
+
inflate difficulty for slices where the class just doesn't exist.
|
| 287 |
+
"""
|
| 288 |
+
df = slice_df.copy()
|
| 289 |
+
|
| 290 |
+
# Fill NaN mean_fg_dice with 0.5 for scoring (class not present → neutral)
|
| 291 |
+
fg_dice_filled = df['mean_fg_dice'].fillna(0.5)
|
| 292 |
+
conf_filled = df['mean_confidence'].fillna(0.5)
|
| 293 |
+
|
| 294 |
+
df['difficulty_score'] = (
|
| 295 |
+
fg_dice_weight * (1 - fg_dice_filled) +
|
| 296 |
+
error_rate_weight * df['error_rate'] +
|
| 297 |
+
confidence_weight * (1 - conf_filled)
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
df = df.sort_values('difficulty_score', ascending=False).reset_index(drop=True)
|
| 301 |
+
df['difficulty_rank'] = df.index + 1
|
| 302 |
+
|
| 303 |
+
return df
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def rank_patients(patient_df):
|
| 307 |
+
"""Sort patients from hardest to easiest (lowest composite dice first)."""
|
| 308 |
+
df = patient_df.copy()
|
| 309 |
+
df = df.sort_values('composite_dice', ascending=True).reset_index(drop=True)
|
| 310 |
+
df['difficulty_rank'] = df.index + 1
|
| 311 |
+
return df
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 315 |
+
# SECTION 4 — Visualization helpers
|
| 316 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 317 |
+
|
| 318 |
+
CLASS_COLORS_3 = ['black', '#2196F3', '#F44336'] # BG, Vent, WMH
|
| 319 |
+
CLASS_COLORS_4 = ['black', '#2196F3', '#4CAF50', '#F44336'] # BG, Vent, NormWMH, AbWMH
|
| 320 |
+
|
| 321 |
+
ERROR_CMAP = ListedColormap(['#1A1A1A', # correct background
|
| 322 |
+
'#FF5722', # FP (pred fg, gt bg)
|
| 323 |
+
'#03A9F4', # FN (gt fg, pred bg)
|
| 324 |
+
'#FFEB3B']) # class confusion
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _get_class_cmap(num_classes):
|
| 328 |
+
colors = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4
|
| 329 |
+
cmap = ListedColormap(colors)
|
| 330 |
+
norm = BoundaryNorm(range(num_classes + 1), num_classes)
|
| 331 |
+
return cmap, norm
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _build_error_rgb(gt_hw, pred_hw, num_classes):
|
| 335 |
+
"""
|
| 336 |
+
Build a pixel-wise error classification map:
|
| 337 |
+
0 = correct
|
| 338 |
+
1 = false positive (model predicts fg, GT is bg)
|
| 339 |
+
2 = false negative (GT is fg, model predicts bg)
|
| 340 |
+
3 = class confusion (both fg but wrong class)
|
| 341 |
+
"""
|
| 342 |
+
gt_fg = gt_hw > 0
|
| 343 |
+
pred_fg = pred_hw > 0
|
| 344 |
+
|
| 345 |
+
err = np.zeros_like(gt_hw, dtype=np.uint8)
|
| 346 |
+
err[~gt_fg & pred_fg] = 1 # FP
|
| 347 |
+
err[gt_fg & ~pred_fg] = 2 # FN
|
| 348 |
+
err[gt_fg & pred_fg & (gt_hw != pred_hw)] = 3 # confusion
|
| 349 |
+
return err
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _add_class_legend(ax, class_names, num_classes):
|
| 353 |
+
colors = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4
|
| 354 |
+
patches = [mpatches.Patch(color=colors[i], label=class_names[i])
|
| 355 |
+
for i in range(num_classes)]
|
| 356 |
+
ax.legend(handles=patches, loc='lower right', fontsize=7,
|
| 357 |
+
framealpha=0.8, markerscale=0.8)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 361 |
+
# SECTION 5 — Diagnostic slice visualization
|
| 362 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 363 |
+
|
| 364 |
+
def visualize_hard_slice(flair, gt_hw, pred_hw, prob_hw,
|
| 365 |
+
slice_metrics_row, class_names, num_classes,
|
| 366 |
+
save_path, rank=None):
|
| 367 |
+
"""
|
| 368 |
+
Create a rich 3-row diagnostic panel for a single hard slice.
|
| 369 |
+
|
| 370 |
+
Row 1 : FLAIR | GT mask | Predicted mask | Overlay (GT contour on FLAIR)
|
| 371 |
+
Row 2 : Confidence map | Error type map | GT vs Pred contour overlay
|
| 372 |
+
Row 3 : Per-class dice bar chart | FP/FN summary table
|
| 373 |
+
"""
|
| 374 |
+
cmap_cls, norm_cls = _get_class_cmap(num_classes)
|
| 375 |
+
err_map = _build_error_rgb(gt_hw, pred_hw, num_classes)
|
| 376 |
+
|
| 377 |
+
patient_id = slice_metrics_row.get('patient_id', '?')
|
| 378 |
+
slice_num = slice_metrics_row.get('slice_num', '?')
|
| 379 |
+
diff_score = slice_metrics_row.get('difficulty_score', float('nan'))
|
| 380 |
+
diff_rank = slice_metrics_row.get('difficulty_rank', rank)
|
| 381 |
+
mean_conf = slice_metrics_row.get('mean_confidence', float('nan'))
|
| 382 |
+
mean_fg_d = slice_metrics_row.get('mean_fg_dice', float('nan'))
|
| 383 |
+
|
| 384 |
+
fig = plt.figure(figsize=(20, 14))
|
| 385 |
+
fig.patch.set_facecolor('#0D0D0D')
|
| 386 |
+
title_str = (f"Patient: {patient_id} | Slice: {slice_num:03d} | "
|
| 387 |
+
f"Rank #{diff_rank} | Difficulty: {diff_score:.3f} | "
|
| 388 |
+
f"Mean FG Dice: {mean_fg_d:.3f} | Mean Conf: {mean_conf:.3f}")
|
| 389 |
+
fig.suptitle(title_str, color='white', fontsize=12, fontweight='bold', y=0.98)
|
| 390 |
+
|
| 391 |
+
gs = gridspec.GridSpec(3, 4, figure=fig,
|
| 392 |
+
hspace=0.35, wspace=0.25,
|
| 393 |
+
left=0.04, right=0.98,
|
| 394 |
+
top=0.93, bottom=0.04)
|
| 395 |
+
|
| 396 |
+
def styled_ax(pos):
|
| 397 |
+
ax = fig.add_subplot(pos)
|
| 398 |
+
ax.set_facecolor('#0D0D0D')
|
| 399 |
+
ax.tick_params(colors='white')
|
| 400 |
+
for spine in ax.spines.values():
|
| 401 |
+
spine.set_edgecolor('#444')
|
| 402 |
+
return ax
|
| 403 |
+
|
| 404 |
+
# ── Row 0 ──────────────────────────────────────────────────────────────
|
| 405 |
+
ax00 = styled_ax(gs[0, 0])
|
| 406 |
+
ax00.imshow(flair, cmap='gray', vmin=flair.min(), vmax=flair.max())
|
| 407 |
+
ax00.set_title('FLAIR', color='white', fontsize=10)
|
| 408 |
+
ax00.axis('off')
|
| 409 |
+
|
| 410 |
+
ax01 = styled_ax(gs[0, 1])
|
| 411 |
+
ax01.imshow(gt_hw, cmap=cmap_cls, norm=norm_cls, interpolation='nearest')
|
| 412 |
+
ax01.set_title('Ground Truth', color='white', fontsize=10)
|
| 413 |
+
ax01.axis('off')
|
| 414 |
+
_add_class_legend(ax01, class_names, num_classes)
|
| 415 |
+
|
| 416 |
+
ax02 = styled_ax(gs[0, 2])
|
| 417 |
+
ax02.imshow(pred_hw, cmap=cmap_cls, norm=norm_cls, interpolation='nearest')
|
| 418 |
+
ax02.set_title('Prediction', color='white', fontsize=10)
|
| 419 |
+
ax02.axis('off')
|
| 420 |
+
_add_class_legend(ax02, class_names, num_classes)
|
| 421 |
+
|
| 422 |
+
# GT contour overlay on FLAIR
|
| 423 |
+
ax03 = styled_ax(gs[0, 3])
|
| 424 |
+
ax03.imshow(flair, cmap='gray', vmin=flair.min(), vmax=flair.max())
|
| 425 |
+
colors_cls = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4
|
| 426 |
+
for cls in range(1, num_classes):
|
| 427 |
+
gt_bin = (gt_hw == cls).astype(np.uint8)
|
| 428 |
+
pred_bin = (pred_hw == cls).astype(np.uint8)
|
| 429 |
+
if gt_bin.any():
|
| 430 |
+
ax03.contour(gt_bin, levels=[0.5], colors=[colors_cls[cls]],
|
| 431 |
+
linewidths=1.5, linestyles='solid')
|
| 432 |
+
if pred_bin.any():
|
| 433 |
+
ax03.contour(pred_bin, levels=[0.5], colors=[colors_cls[cls]],
|
| 434 |
+
linewidths=1.2, linestyles='dashed')
|
| 435 |
+
gt_patch = mpatches.Patch(color='white', linestyle='solid', label='GT (solid)')
|
| 436 |
+
pred_patch = mpatches.Patch(color='white', linestyle='dashed', label='Pred (dashed)')
|
| 437 |
+
ax03.legend(handles=[gt_patch, pred_patch], loc='lower right',
|
| 438 |
+
fontsize=7, framealpha=0.7)
|
| 439 |
+
ax03.set_title('GT vs Pred Contours', color='white', fontsize=10)
|
| 440 |
+
ax03.axis('off')
|
| 441 |
+
|
| 442 |
+
# ── Row 1 ──────────────────────────────────────────────────────────────
|
| 443 |
+
ax10 = styled_ax(gs[1, 0])
|
| 444 |
+
im_conf = ax10.imshow(prob_hw, cmap='plasma', vmin=0, vmax=1)
|
| 445 |
+
plt.colorbar(im_conf, ax=ax10, fraction=0.046, pad=0.04).ax.yaxis.set_tick_params(color='white')
|
| 446 |
+
ax10.set_title('Confidence Map', color='white', fontsize=10)
|
| 447 |
+
ax10.axis('off')
|
| 448 |
+
|
| 449 |
+
# Low-confidence overlay on FLAIR
|
| 450 |
+
ax11 = styled_ax(gs[1, 1])
|
| 451 |
+
ax11.imshow(flair, cmap='gray')
|
| 452 |
+
low_conf_mask = prob_hw < 0.5
|
| 453 |
+
overlay = np.zeros((*flair.shape, 4))
|
| 454 |
+
overlay[low_conf_mask] = [1, 0.3, 0, 0.55] # orange-red for uncertain regions
|
| 455 |
+
ax11.imshow(overlay)
|
| 456 |
+
ax11.set_title('Low-Confidence Regions (<0.5)', color='white', fontsize=10)
|
| 457 |
+
ax11.axis('off')
|
| 458 |
+
|
| 459 |
+
ax12 = styled_ax(gs[1, 2])
|
| 460 |
+
err_colors = ['#1A1A1A', '#FF5722', '#03A9F4', '#FFEB3B']
|
| 461 |
+
err_cmap = ListedColormap(err_colors)
|
| 462 |
+
err_norm = BoundaryNorm([0, 1, 2, 3, 4], 4)
|
| 463 |
+
ax12.imshow(err_map, cmap=err_cmap, norm=err_norm, interpolation='nearest')
|
| 464 |
+
patches_err = [
|
| 465 |
+
mpatches.Patch(color='#1A1A1A', label='Correct'),
|
| 466 |
+
mpatches.Patch(color='#FF5722', label='False Positive'),
|
| 467 |
+
mpatches.Patch(color='#03A9F4', label='False Negative'),
|
| 468 |
+
mpatches.Patch(color='#FFEB3B', label='Class Confusion'),
|
| 469 |
+
]
|
| 470 |
+
ax12.legend(handles=patches_err, loc='lower right', fontsize=6.5, framealpha=0.8)
|
| 471 |
+
ax12.set_title('Error Type Map', color='white', fontsize=10)
|
| 472 |
+
ax12.axis('off')
|
| 473 |
+
|
| 474 |
+
# FLAIR + error overlay
|
| 475 |
+
ax13 = styled_ax(gs[1, 3])
|
| 476 |
+
flair_rgb = np.stack([flair] * 3, axis=-1)
|
| 477 |
+
# Normalise 0-1
|
| 478 |
+
flair_rgb = (flair_rgb - flair_rgb.min()) / (flair_rgb.max() - flair_rgb.min() + 1e-7)
|
| 479 |
+
err_overlay = flair_rgb.copy()
|
| 480 |
+
err_overlay[err_map == 1] = [1.0, 0.34, 0.13] # FP
|
| 481 |
+
err_overlay[err_map == 2] = [0.01, 0.66, 0.96] # FN
|
| 482 |
+
err_overlay[err_map == 3] = [1.0, 0.92, 0.23] # confusion
|
| 483 |
+
ax13.imshow(err_overlay)
|
| 484 |
+
ax13.set_title('FLAIR + Error Overlay', color='white', fontsize=10)
|
| 485 |
+
ax13.axis('off')
|
| 486 |
+
|
| 487 |
+
# ── Row 2: metrics ─────────────────────────────────────────────────────
|
| 488 |
+
ax20 = styled_ax(gs[2, 0:2])
|
| 489 |
+
ax20.set_facecolor('#111')
|
| 490 |
+
|
| 491 |
+
bar_labels = []
|
| 492 |
+
bar_dice = []
|
| 493 |
+
bar_colors = []
|
| 494 |
+
for cls in range(1, num_classes):
|
| 495 |
+
cname = class_names[cls]
|
| 496 |
+
prefix = cname.lower().replace(' ', '_')
|
| 497 |
+
d = slice_metrics_row.get(f'{prefix}_dice', np.nan)
|
| 498 |
+
bar_labels.append(cname)
|
| 499 |
+
bar_dice.append(d if not np.isnan(d) else 0)
|
| 500 |
+
bar_colors.append(colors_cls[cls])
|
| 501 |
+
|
| 502 |
+
x = np.arange(len(bar_labels))
|
| 503 |
+
bars = ax20.bar(x, bar_dice, color=bar_colors, edgecolor='white',
|
| 504 |
+
linewidth=0.8, width=0.5)
|
| 505 |
+
ax20.axhline(0.5, color='red', linestyle='--', linewidth=1, label='Threshold 0.5')
|
| 506 |
+
ax20.axhline(0.8, color='yellow', linestyle='--', linewidth=1, label='Good 0.8')
|
| 507 |
+
ax20.set_xticks(x)
|
| 508 |
+
ax20.set_xticklabels(bar_labels, color='white', fontsize=9)
|
| 509 |
+
ax20.set_ylim(0, 1.05)
|
| 510 |
+
ax20.set_ylabel('Dice Score', color='white', fontsize=9)
|
| 511 |
+
ax20.set_title('Per-Class Dice', color='white', fontsize=10)
|
| 512 |
+
ax20.tick_params(axis='y', colors='white')
|
| 513 |
+
ax20.legend(fontsize=7, labelcolor='white', framealpha=0.3)
|
| 514 |
+
for bar, val in zip(bars, bar_dice):
|
| 515 |
+
ax20.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
|
| 516 |
+
f'{val:.3f}', ha='center', color='white', fontsize=9)
|
| 517 |
+
|
| 518 |
+
# Table: per-class FP/FN/precision/recall
|
| 519 |
+
ax21 = styled_ax(gs[2, 2:4])
|
| 520 |
+
ax21.axis('off')
|
| 521 |
+
|
| 522 |
+
col_labels = ['Class', 'Dice', 'Prec', 'Recall', 'FP rate', 'FN rate',
|
| 523 |
+
'GT px', 'Pred px']
|
| 524 |
+
table_data = []
|
| 525 |
+
for cls in range(1, num_classes):
|
| 526 |
+
cname = class_names[cls]
|
| 527 |
+
prefix = cname.lower().replace(' ', '_')
|
| 528 |
+
def _g(k):
|
| 529 |
+
v = slice_metrics_row.get(f'{prefix}_{k}', np.nan)
|
| 530 |
+
return f'{v:.3f}' if not np.isnan(v) else 'N/A'
|
| 531 |
+
table_data.append([
|
| 532 |
+
cname,
|
| 533 |
+
_g('dice'), _g('precision'), _g('recall'),
|
| 534 |
+
_g('fp_rate'), _g('fn_rate'),
|
| 535 |
+
str(int(slice_metrics_row.get(f'{prefix}_gt_px', 0))),
|
| 536 |
+
str(int(slice_metrics_row.get(f'{prefix}_pred_px', 0))),
|
| 537 |
+
])
|
| 538 |
+
|
| 539 |
+
tbl = ax21.table(cellText=table_data, colLabels=col_labels,
|
| 540 |
+
cellLoc='center', loc='center')
|
| 541 |
+
tbl.auto_set_font_size(False)
|
| 542 |
+
tbl.set_fontsize(8)
|
| 543 |
+
tbl.scale(1, 1.6)
|
| 544 |
+
for (r, c), cell in tbl.get_celld().items():
|
| 545 |
+
cell.set_edgecolor('#444')
|
| 546 |
+
if r == 0:
|
| 547 |
+
cell.set_facecolor('#2C2C2C')
|
| 548 |
+
cell.set_text_props(color='white', fontweight='bold')
|
| 549 |
+
else:
|
| 550 |
+
cell.set_facecolor('#1A1A1A')
|
| 551 |
+
cell.set_text_props(color='white')
|
| 552 |
+
ax21.set_title('Per-Class Metrics Summary', color='white', fontsize=10, pad=8)
|
| 553 |
+
|
| 554 |
+
plt.savefig(save_path, dpi=130, bbox_inches='tight',
|
| 555 |
+
facecolor=fig.get_facecolor())
|
| 556 |
+
plt.close(fig)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 560 |
+
# SECTION 6 — Patient-level summary visualization
|
| 561 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 562 |
+
|
| 563 |
+
def visualize_patient_summary(patient_id, patient_data, slice_df_patient,
|
| 564 |
+
class_names, num_classes, save_path):
|
| 565 |
+
"""
|
| 566 |
+
One-page summary for a single patient showing:
|
| 567 |
+
- Dice scores across all slices (line plot per class)
|
| 568 |
+
- Confidence vs. error rate scatter
|
| 569 |
+
- Per-slice FP / FN bar chart
|
| 570 |
+
- Overall dice distribution box plots
|
| 571 |
+
"""
|
| 572 |
+
order = np.argsort(patient_data['slice_indices'])
|
| 573 |
+
slices = np.array(patient_data['slice_indices'])[order]
|
| 574 |
+
n_slices = len(slices)
|
| 575 |
+
|
| 576 |
+
fig, axes = plt.subplots(2, 2, figsize=(18, 10))
|
| 577 |
+
fig.patch.set_facecolor('#0D0D0D')
|
| 578 |
+
fig.suptitle(f'Patient Summary | ID: {patient_id} | {n_slices} slices',
|
| 579 |
+
color='white', fontsize=13, fontweight='bold')
|
| 580 |
+
|
| 581 |
+
colors_cls = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4
|
| 582 |
+
|
| 583 |
+
df = slice_df_patient.sort_values('slice_num').reset_index(drop=True)
|
| 584 |
+
|
| 585 |
+
# ── Plot 1: Per-slice Dice per class ──────────────────────────────────
|
| 586 |
+
ax = axes[0, 0]
|
| 587 |
+
ax.set_facecolor('#111')
|
| 588 |
+
for cls in range(1, num_classes):
|
| 589 |
+
cname = class_names[cls]
|
| 590 |
+
prefix = cname.lower().replace(' ', '_')
|
| 591 |
+
col = f'{prefix}_dice'
|
| 592 |
+
if col in df.columns:
|
| 593 |
+
valid = df[col].notna()
|
| 594 |
+
ax.plot(df.loc[valid, 'slice_num'], df.loc[valid, col],
|
| 595 |
+
color=colors_cls[cls], linewidth=1.5,
|
| 596 |
+
marker='o', markersize=3, label=cname)
|
| 597 |
+
ax.axhline(0.5, color='red', linestyle='--', linewidth=0.8, alpha=0.7)
|
| 598 |
+
ax.axhline(0.8, color='yellow', linestyle='--', linewidth=0.8, alpha=0.7)
|
| 599 |
+
ax.set_xlabel('Slice Number', color='white')
|
| 600 |
+
ax.set_ylabel('Dice Score', color='white')
|
| 601 |
+
ax.set_title('Per-Slice Dice by Class', color='white', fontsize=10)
|
| 602 |
+
ax.legend(fontsize=8, labelcolor='white', framealpha=0.3)
|
| 603 |
+
ax.tick_params(colors='white')
|
| 604 |
+
for spine in ax.spines.values():
|
| 605 |
+
spine.set_edgecolor('#444')
|
| 606 |
+
ax.set_ylim(0, 1.05)
|
| 607 |
+
|
| 608 |
+
# ── Plot 2: Confidence vs Error rate scatter ───────────────────────────
|
| 609 |
+
ax = axes[0, 1]
|
| 610 |
+
ax.set_facecolor('#111')
|
| 611 |
+
sc = ax.scatter(df['mean_confidence'], df['error_rate'],
|
| 612 |
+
c=df['mean_fg_dice'].fillna(0.5),
|
| 613 |
+
cmap='RdYlGn', vmin=0, vmax=1,
|
| 614 |
+
s=50, edgecolors='white', linewidths=0.3, alpha=0.85)
|
| 615 |
+
cbar = plt.colorbar(sc, ax=ax)
|
| 616 |
+
cbar.set_label('Mean FG Dice', color='white')
|
| 617 |
+
cbar.ax.yaxis.set_tick_params(color='white')
|
| 618 |
+
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')
|
| 619 |
+
ax.set_xlabel('Mean Confidence', color='white')
|
| 620 |
+
ax.set_ylabel('Pixel Error Rate', color='white')
|
| 621 |
+
ax.set_title('Confidence vs Error Rate\n(colour = Mean FG Dice)',
|
| 622 |
+
color='white', fontsize=10)
|
| 623 |
+
ax.tick_params(colors='white')
|
| 624 |
+
for spine in ax.spines.values():
|
| 625 |
+
spine.set_edgecolor('#444')
|
| 626 |
+
|
| 627 |
+
# Annotate worst 3 slices
|
| 628 |
+
worst3 = df.nlargest(3, 'difficulty_score') if 'difficulty_score' in df.columns \
|
| 629 |
+
else df.nlargest(3, 'error_rate')
|
| 630 |
+
for _, row in worst3.iterrows():
|
| 631 |
+
ax.annotate(f"sl{int(row['slice_num']):03d}",
|
| 632 |
+
(row['mean_confidence'], row['error_rate']),
|
| 633 |
+
textcoords="offset points", xytext=(5, 5),
|
| 634 |
+
fontsize=7, color='white')
|
| 635 |
+
|
| 636 |
+
# ── Plot 3: FP / FN pixel rates per slice ─────────────────────────────
|
| 637 |
+
ax = axes[1, 0]
|
| 638 |
+
ax.set_facecolor('#111')
|
| 639 |
+
x = df['slice_num'].values
|
| 640 |
+
# Use WMH class (last foreground class) as primary interest
|
| 641 |
+
cls_main = num_classes - 1
|
| 642 |
+
prefix_m = class_names[cls_main].lower().replace(' ', '_')
|
| 643 |
+
fp_col = f'{prefix_m}_fp_rate'
|
| 644 |
+
fn_col = f'{prefix_m}_fn_rate'
|
| 645 |
+
|
| 646 |
+
if fp_col in df.columns and fn_col in df.columns:
|
| 647 |
+
width = 0.4
|
| 648 |
+
ax.bar(x - width/2, df[fp_col].fillna(0), width=width,
|
| 649 |
+
color='#FF5722', alpha=0.8, label='FP Rate')
|
| 650 |
+
ax.bar(x + width/2, df[fn_col].fillna(0), width=width,
|
| 651 |
+
color='#03A9F4', alpha=0.8, label='FN Rate')
|
| 652 |
+
ax.set_xlabel('Slice Number', color='white')
|
| 653 |
+
ax.set_ylabel('Rate', color='white')
|
| 654 |
+
ax.set_title(f'FP / FN Rate per Slice [{class_names[cls_main]}]',
|
| 655 |
+
color='white', fontsize=10)
|
| 656 |
+
ax.legend(fontsize=8, labelcolor='white', framealpha=0.3)
|
| 657 |
+
ax.tick_params(colors='white')
|
| 658 |
+
for spine in ax.spines.values():
|
| 659 |
+
spine.set_edgecolor('#444')
|
| 660 |
+
|
| 661 |
+
# ── Plot 4: Dice distribution box plots ───────────────────────────────
|
| 662 |
+
ax = axes[1, 1]
|
| 663 |
+
ax.set_facecolor('#111')
|
| 664 |
+
box_data = []
|
| 665 |
+
box_labels = []
|
| 666 |
+
box_colors = []
|
| 667 |
+
for cls in range(1, num_classes):
|
| 668 |
+
cname = class_names[cls]
|
| 669 |
+
prefix = cname.lower().replace(' ', '_')
|
| 670 |
+
col = f'{prefix}_dice'
|
| 671 |
+
vals = df[col].dropna().values if col in df.columns else np.array([])
|
| 672 |
+
box_data.append(vals)
|
| 673 |
+
box_labels.append(cname)
|
| 674 |
+
box_colors.append(colors_cls[cls])
|
| 675 |
+
|
| 676 |
+
bp = ax.boxplot(box_data, patch_artist=True,
|
| 677 |
+
medianprops=dict(color='white', linewidth=2))
|
| 678 |
+
for patch, color in zip(bp['boxes'], box_colors):
|
| 679 |
+
patch.set_facecolor(color)
|
| 680 |
+
patch.set_alpha(0.7)
|
| 681 |
+
for element in ['whiskers', 'caps', 'fliers']:
|
| 682 |
+
for item in bp[element]:
|
| 683 |
+
item.set_color('white')
|
| 684 |
+
|
| 685 |
+
ax.set_xticklabels(box_labels, color='white')
|
| 686 |
+
ax.set_ylabel('Dice Score', color='white')
|
| 687 |
+
ax.set_title('Dice Score Distribution per Class', color='white', fontsize=10)
|
| 688 |
+
ax.axhline(0.5, color='red', linestyle='--', linewidth=0.8, alpha=0.7)
|
| 689 |
+
ax.axhline(0.8, color='yellow', linestyle='--', linewidth=0.8, alpha=0.7)
|
| 690 |
+
ax.tick_params(colors='white')
|
| 691 |
+
for spine in ax.spines.values():
|
| 692 |
+
spine.set_edgecolor('#444')
|
| 693 |
+
ax.set_ylim(0, 1.05)
|
| 694 |
+
|
| 695 |
+
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
| 696 |
+
plt.savefig(save_path, dpi=120, bbox_inches='tight',
|
| 697 |
+
facecolor=fig.get_facecolor())
|
| 698 |
+
plt.close(fig)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 702 |
+
# SECTION 7 — Dataset-level overview visualizations
|
| 703 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 704 |
+
|
| 705 |
+
def visualize_dataset_overview(slice_df, patient_df, class_names,
|
| 706 |
+
num_classes, save_dir):
|
| 707 |
+
"""
|
| 708 |
+
Global overview plots saved to save_dir/overview/:
|
| 709 |
+
1. Dice distribution across all slices (violin per class)
|
| 710 |
+
2. Patient ranking bar chart (composite dice)
|
| 711 |
+
3. Error rate histogram
|
| 712 |
+
4. Confidence vs dice scatter (all slices)
|
| 713 |
+
5. Difficulty score distribution
|
| 714 |
+
"""
|
| 715 |
+
overview_dir = Path(save_dir) / 'overview'
|
| 716 |
+
overview_dir.mkdir(parents=True, exist_ok=True)
|
| 717 |
+
|
| 718 |
+
colors_cls = CLASS_COLORS_3 if num_classes == 3 else CLASS_COLORS_4
|
| 719 |
+
|
| 720 |
+
# ── 1. Dice violin ────────────────────────────────────────────────────
|
| 721 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 722 |
+
fig.patch.set_facecolor('#0D0D0D')
|
| 723 |
+
ax.set_facecolor('#111')
|
| 724 |
+
|
| 725 |
+
violin_data = []
|
| 726 |
+
violin_labels = []
|
| 727 |
+
for cls in range(1, num_classes):
|
| 728 |
+
cname = class_names[cls]
|
| 729 |
+
prefix = cname.lower().replace(' ', '_')
|
| 730 |
+
col = f'{prefix}_dice'
|
| 731 |
+
vals = slice_df[col].dropna().values if col in slice_df.columns else np.array([])
|
| 732 |
+
violin_data.append(vals)
|
| 733 |
+
violin_labels.append(cname)
|
| 734 |
+
|
| 735 |
+
parts = ax.violinplot(violin_data, showmedians=True, showextrema=True)
|
| 736 |
+
for i, (pc, color) in enumerate(zip(parts['bodies'],
|
| 737 |
+
[colors_cls[c] for c in range(1, num_classes)])):
|
| 738 |
+
pc.set_facecolor(color)
|
| 739 |
+
pc.set_alpha(0.7)
|
| 740 |
+
parts['cmedians'].set_colors('white')
|
| 741 |
+
parts['cmaxes'].set_colors('#aaa')
|
| 742 |
+
parts['cmins'].set_colors('#aaa')
|
| 743 |
+
parts['cbars'].set_colors('#aaa')
|
| 744 |
+
|
| 745 |
+
ax.set_xticks(range(1, len(violin_labels) + 1))
|
| 746 |
+
ax.set_xticklabels(violin_labels, color='white')
|
| 747 |
+
ax.axhline(0.5, color='red', linestyle='--', linewidth=0.9, label='0.5 threshold')
|
| 748 |
+
ax.axhline(0.8, color='yellow', linestyle='--', linewidth=0.9, label='0.8 target')
|
| 749 |
+
ax.set_ylabel('Dice Score', color='white')
|
| 750 |
+
ax.set_title('Dice Distribution — All Slices', color='white', fontsize=12)
|
| 751 |
+
ax.tick_params(colors='white')
|
| 752 |
+
ax.legend(fontsize=8, labelcolor='white', framealpha=0.3)
|
| 753 |
+
for spine in ax.spines.values():
|
| 754 |
+
spine.set_edgecolor('#444')
|
| 755 |
+
ax.set_ylim(0, 1.05)
|
| 756 |
+
|
| 757 |
+
plt.tight_layout()
|
| 758 |
+
plt.savefig(overview_dir / 'dice_violin_all_slices.png', dpi=130,
|
| 759 |
+
bbox_inches='tight', facecolor=fig.get_facecolor())
|
| 760 |
+
plt.close(fig)
|
| 761 |
+
|
| 762 |
+
# ── 2. Patient ranking bar chart ──────────────────────────────────────
|
| 763 |
+
pat_sorted = patient_df.sort_values('composite_dice').reset_index(drop=True)
|
| 764 |
+
n_patients = len(pat_sorted)
|
| 765 |
+
|
| 766 |
+
fig, ax = plt.subplots(figsize=(max(12, n_patients * 0.6), 5))
|
| 767 |
+
fig.patch.set_facecolor('#0D0D0D')
|
| 768 |
+
ax.set_facecolor('#111')
|
| 769 |
+
|
| 770 |
+
bar_colors = ['#F44336' if v < 0.5 else '#FFC107' if v < 0.7 else '#4CAF50'
|
| 771 |
+
for v in pat_sorted['composite_dice'].fillna(0)]
|
| 772 |
+
ax.bar(range(n_patients), pat_sorted['composite_dice'].fillna(0),
|
| 773 |
+
color=bar_colors, edgecolor='#333', linewidth=0.5)
|
| 774 |
+
ax.set_xticks(range(n_patients))
|
| 775 |
+
ax.set_xticklabels(pat_sorted['patient_id'], rotation=75,
|
| 776 |
+
ha='right', color='white', fontsize=7)
|
| 777 |
+
ax.axhline(0.5, color='red', linestyle='--', linewidth=0.9)
|
| 778 |
+
ax.axhline(0.7, color='orange', linestyle='--', linewidth=0.9)
|
| 779 |
+
ax.axhline(0.8, color='yellow', linestyle='--', linewidth=0.9)
|
| 780 |
+
ax.set_ylabel('Composite Dice (mean FG classes)', color='white')
|
| 781 |
+
ax.set_title('Patient Ranking — Composite Dice (worst → best)',
|
| 782 |
+
color='white', fontsize=12)
|
| 783 |
+
ax.tick_params(colors='white')
|
| 784 |
+
for spine in ax.spines.values():
|
| 785 |
+
spine.set_edgecolor('#444')
|
| 786 |
+
ax.set_ylim(0, 1.05)
|
| 787 |
+
|
| 788 |
+
red_p = mpatches.Patch(color='#F44336', label='< 0.5 (critical)')
|
| 789 |
+
orange_p = mpatches.Patch(color='#FFC107', label='0.5–0.7 (poor)')
|
| 790 |
+
green_p = mpatches.Patch(color='#4CAF50', label='≥ 0.7 (acceptable)')
|
| 791 |
+
ax.legend(handles=[red_p, orange_p, green_p],
|
| 792 |
+
fontsize=8, labelcolor='white', framealpha=0.3)
|
| 793 |
+
|
| 794 |
+
plt.tight_layout()
|
| 795 |
+
plt.savefig(overview_dir / 'patient_ranking.png', dpi=130,
|
| 796 |
+
bbox_inches='tight', facecolor=fig.get_facecolor())
|
| 797 |
+
plt.close(fig)
|
| 798 |
+
|
| 799 |
+
# ── 3. Error rate histogram ────────────────────────────────────────────
|
| 800 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 801 |
+
fig.patch.set_facecolor('#0D0D0D')
|
| 802 |
+
ax.set_facecolor('#111')
|
| 803 |
+
ax.hist(slice_df['error_rate'].dropna(), bins=40, color='#9C27B0',
|
| 804 |
+
edgecolor='white', linewidth=0.3, alpha=0.85)
|
| 805 |
+
ax.set_xlabel('Pixel Error Rate per Slice', color='white')
|
| 806 |
+
ax.set_ylabel('Count', color='white')
|
| 807 |
+
ax.set_title('Pixel Error Rate Distribution — All Slices', color='white', fontsize=12)
|
| 808 |
+
ax.tick_params(colors='white')
|
| 809 |
+
for spine in ax.spines.values():
|
| 810 |
+
spine.set_edgecolor('#444')
|
| 811 |
+
plt.tight_layout()
|
| 812 |
+
plt.savefig(overview_dir / 'error_rate_histogram.png', dpi=130,
|
| 813 |
+
bbox_inches='tight', facecolor=fig.get_facecolor())
|
| 814 |
+
plt.close(fig)
|
| 815 |
+
|
| 816 |
+
# ── 4. Confidence vs mean FG Dice scatter ─────────────────────────────
|
| 817 |
+
fig, ax = plt.subplots(figsize=(9, 6))
|
| 818 |
+
fig.patch.set_facecolor('#0D0D0D')
|
| 819 |
+
ax.set_facecolor('#111')
|
| 820 |
+
sc = ax.scatter(slice_df['mean_confidence'], slice_df['mean_fg_dice'].fillna(0),
|
| 821 |
+
c=slice_df['error_rate'], cmap='RdYlGn_r',
|
| 822 |
+
vmin=0, vmax=0.3, s=10, alpha=0.6)
|
| 823 |
+
cbar = plt.colorbar(sc, ax=ax)
|
| 824 |
+
cbar.set_label('Pixel Error Rate', color='white')
|
| 825 |
+
cbar.ax.yaxis.set_tick_params(color='white')
|
| 826 |
+
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')
|
| 827 |
+
ax.set_xlabel('Mean Softmax Confidence', color='white')
|
| 828 |
+
ax.set_ylabel('Mean FG Dice', color='white')
|
| 829 |
+
ax.set_title('Confidence vs FG Dice — All Slices', color='white', fontsize=12)
|
| 830 |
+
ax.tick_params(colors='white')
|
| 831 |
+
for spine in ax.spines.values():
|
| 832 |
+
spine.set_edgecolor('#444')
|
| 833 |
+
plt.tight_layout()
|
| 834 |
+
plt.savefig(overview_dir / 'confidence_vs_dice_scatter.png', dpi=130,
|
| 835 |
+
bbox_inches='tight', facecolor=fig.get_facecolor())
|
| 836 |
+
plt.close(fig)
|
| 837 |
+
|
| 838 |
+
# ── 5. Difficulty score distribution ──────────────────────────────────
|
| 839 |
+
if 'difficulty_score' in slice_df.columns:
|
| 840 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 841 |
+
fig.patch.set_facecolor('#0D0D0D')
|
| 842 |
+
ax.set_facecolor('#111')
|
| 843 |
+
ax.hist(slice_df['difficulty_score'].dropna(), bins=40,
|
| 844 |
+
color='#FF9800', edgecolor='white', linewidth=0.3, alpha=0.85)
|
| 845 |
+
ax.set_xlabel('Difficulty Score', color='white')
|
| 846 |
+
ax.set_ylabel('Count', color='white')
|
| 847 |
+
ax.set_title('Difficulty Score Distribution — All Slices', color='white', fontsize=12)
|
| 848 |
+
ax.tick_params(colors='white')
|
| 849 |
+
for spine in ax.spines.values():
|
| 850 |
+
spine.set_edgecolor('#444')
|
| 851 |
+
plt.tight_layout()
|
| 852 |
+
plt.savefig(overview_dir / 'difficulty_score_histogram.png', dpi=130,
|
| 853 |
+
bbox_inches='tight', facecolor=fig.get_facecolor())
|
| 854 |
+
plt.close(fig)
|
| 855 |
+
|
| 856 |
+
print(f" ✅ Overview plots saved to: {overview_dir}")
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 860 |
+
# SECTION 8 — Main entry point: run_error_analysis
|
| 861 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 862 |
+
|
| 863 |
+
def run_error_analysis(results, config,
|
| 864 |
+
top_n_slices=30,
|
| 865 |
+
top_n_patients=10,
|
| 866 |
+
fg_dice_weight=0.6,
|
| 867 |
+
error_rate_weight=0.2,
|
| 868 |
+
confidence_weight=0.2):
|
| 869 |
+
"""
|
| 870 |
+
Full pipeline: build tables → rank → save CSVs → generate visualizations.
|
| 871 |
+
|
| 872 |
+
Call after run_inference():
|
| 873 |
+
results = run_inference(config)
|
| 874 |
+
run_error_analysis(results, config)
|
| 875 |
+
|
| 876 |
+
Parameters
|
| 877 |
+
----------
|
| 878 |
+
results : dict — returned by run_inference()
|
| 879 |
+
config : InferenceConfig
|
| 880 |
+
top_n_slices : int — how many hardest slices to visualize individually
|
| 881 |
+
top_n_patients : int — how many hardest patients to get summary plots
|
| 882 |
+
fg_dice_weight, error_rate_weight, confidence_weight : floats for ranking
|
| 883 |
+
"""
|
| 884 |
+
patient_results = results['patients_results']
|
| 885 |
+
class_names = config.class_names
|
| 886 |
+
num_classes = config.num_classes
|
| 887 |
+
|
| 888 |
+
# Output sub-directories
|
| 889 |
+
error_dir = config.inference_dir / 'error_analysis'
|
| 890 |
+
hard_slices_dir = error_dir / 'hard_slices'
|
| 891 |
+
patient_summaries_dir = error_dir / 'patient_summaries'
|
| 892 |
+
tables_dir = error_dir / 'tables'
|
| 893 |
+
|
| 894 |
+
for d in [hard_slices_dir, patient_summaries_dir, tables_dir]:
|
| 895 |
+
d.mkdir(parents=True, exist_ok=True)
|
| 896 |
+
|
| 897 |
+
print("\n" + "=" * 70)
|
| 898 |
+
print("ERROR ANALYSIS — Building slice & patient tables")
|
| 899 |
+
print("=" * 70)
|
| 900 |
+
|
| 901 |
+
# ── Step 1: build tables ──────────────────────────────────────────────
|
| 902 |
+
slice_df, patient_df = build_error_tables(patient_results, num_classes, class_names)
|
| 903 |
+
|
| 904 |
+
# ── Step 2: rank ──────────────────────────────────────────────────────
|
| 905 |
+
slice_df = rank_slices(slice_df, class_names, num_classes,
|
| 906 |
+
fg_dice_weight, error_rate_weight, confidence_weight)
|
| 907 |
+
patient_df = rank_patients(patient_df)
|
| 908 |
+
|
| 909 |
+
# ── Step 3: save CSVs ─────────────────────────────────────────────────
|
| 910 |
+
slice_csv = tables_dir / 'slice_difficulty_ranking.csv'
|
| 911 |
+
patient_csv = tables_dir / 'patient_difficulty_ranking.csv'
|
| 912 |
+
slice_df.to_csv(slice_csv, index=False)
|
| 913 |
+
patient_df.to_csv(patient_csv, index=False)
|
| 914 |
+
print(f" ✅ Slice table → {slice_csv}")
|
| 915 |
+
print(f" ✅ Patient table → {patient_csv}")
|
| 916 |
+
|
| 917 |
+
# ── Step 4: dataset overview plots ────────────────────────────────────
|
| 918 |
+
print("\nGenerating dataset overview plots...")
|
| 919 |
+
visualize_dataset_overview(slice_df, patient_df, class_names,
|
| 920 |
+
num_classes, error_dir)
|
| 921 |
+
|
| 922 |
+
# ── Step 5: hard slice visualizations ────────────────────────────────
|
| 923 |
+
print(f"\nVisualizing top-{top_n_slices} hardest slices...")
|
| 924 |
+
hard_slices = slice_df.head(top_n_slices)
|
| 925 |
+
|
| 926 |
+
for _, row in tqdm(hard_slices.iterrows(),
|
| 927 |
+
total=len(hard_slices), desc="Hard slice panels"):
|
| 928 |
+
patient_id = row['patient_id']
|
| 929 |
+
slice_num = int(row['slice_num'])
|
| 930 |
+
|
| 931 |
+
data = patient_results[patient_id]
|
| 932 |
+
order = np.argsort(data['slice_indices'])
|
| 933 |
+
slices_sorted = np.array(data['slice_indices'])[order]
|
| 934 |
+
|
| 935 |
+
# Find position of this slice
|
| 936 |
+
pos = np.where(slices_sorted == slice_num)[0]
|
| 937 |
+
if len(pos) == 0:
|
| 938 |
+
continue
|
| 939 |
+
pos = pos[0]
|
| 940 |
+
|
| 941 |
+
gts = np.array(data['ground_truths'])[order]
|
| 942 |
+
preds = np.array(data['predictions'])[order]
|
| 943 |
+
probs = np.array(data['probabilities'])[order]
|
| 944 |
+
flairs = np.array(data['flairs'])[order]
|
| 945 |
+
|
| 946 |
+
gt_hw = gts[pos]
|
| 947 |
+
pred_hw = preds[pos]
|
| 948 |
+
prob_hw = probs[pos]
|
| 949 |
+
flair_hw = flairs[pos]
|
| 950 |
+
|
| 951 |
+
# Collapse one-hot GT if needed
|
| 952 |
+
if gt_hw.ndim == 3:
|
| 953 |
+
gt_hw = np.argmax(gt_hw, axis=-1)
|
| 954 |
+
|
| 955 |
+
rank = int(row['difficulty_rank'])
|
| 956 |
+
fname = (f"rank{rank:04d}_"
|
| 957 |
+
f"{patient_id}_slice{slice_num:03d}"
|
| 958 |
+
f"_dice{row['mean_fg_dice']:.3f}.png")
|
| 959 |
+
save_path = hard_slices_dir / fname
|
| 960 |
+
|
| 961 |
+
visualize_hard_slice(
|
| 962 |
+
flair=flair_hw,
|
| 963 |
+
gt_hw=gt_hw,
|
| 964 |
+
pred_hw=pred_hw,
|
| 965 |
+
prob_hw=prob_hw,
|
| 966 |
+
slice_metrics_row=row.to_dict(),
|
| 967 |
+
class_names=class_names,
|
| 968 |
+
num_classes=num_classes,
|
| 969 |
+
save_path=save_path,
|
| 970 |
+
rank=rank
|
| 971 |
+
)
|
| 972 |
+
|
| 973 |
+
print(f" ✅ Hard slice panels → {hard_slices_dir}")
|
| 974 |
+
|
| 975 |
+
# ── Step 6: patient summary visualizations ────────────────────────────
|
| 976 |
+
print(f"\nGenerating top-{top_n_patients} hardest patient summaries...")
|
| 977 |
+
hard_patients = patient_df.head(top_n_patients)
|
| 978 |
+
|
| 979 |
+
for _, pat_row in tqdm(hard_patients.iterrows(),
|
| 980 |
+
total=len(hard_patients), desc="Patient summaries"):
|
| 981 |
+
patient_id = pat_row['patient_id']
|
| 982 |
+
if patient_id not in patient_results:
|
| 983 |
+
continue
|
| 984 |
+
|
| 985 |
+
data = patient_results[patient_id]
|
| 986 |
+
slice_df_patient = slice_df[slice_df['patient_id'] == patient_id].copy()
|
| 987 |
+
|
| 988 |
+
rank = int(pat_row['difficulty_rank'])
|
| 989 |
+
comp = pat_row.get('composite_dice', float('nan'))
|
| 990 |
+
fname = (f"rank{rank:03d}_{patient_id}"
|
| 991 |
+
f"_composite{comp:.3f}.png")
|
| 992 |
+
save_path = patient_summaries_dir / fname
|
| 993 |
+
|
| 994 |
+
visualize_patient_summary(
|
| 995 |
+
patient_id=patient_id,
|
| 996 |
+
patient_data=data,
|
| 997 |
+
slice_df_patient=slice_df_patient,
|
| 998 |
+
class_names=class_names,
|
| 999 |
+
num_classes=num_classes,
|
| 1000 |
+
save_path=save_path
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
print(f" ✅ Patient summaries → {patient_summaries_dir}")
|
| 1004 |
+
|
| 1005 |
+
# ── Step 7: print console summary ─────────────────────────────────────
|
| 1006 |
+
print("\n" + "=" * 70)
|
| 1007 |
+
print("ERROR ANALYSIS SUMMARY")
|
| 1008 |
+
print("=" * 70)
|
| 1009 |
+
print(f"\nTotal slices analysed : {len(slice_df)}")
|
| 1010 |
+
print(f"Total patients : {len(patient_df)}")
|
| 1011 |
+
|
| 1012 |
+
print(f"\nTop-10 Hardest Slices:")
|
| 1013 |
+
top10_cols = ['difficulty_rank', 'slice_id', 'mean_fg_dice',
|
| 1014 |
+
'error_rate', 'mean_confidence', 'difficulty_score']
|
| 1015 |
+
top10_cols = [c for c in top10_cols if c in slice_df.columns]
|
| 1016 |
+
print(slice_df[top10_cols].head(10).to_string(index=False))
|
| 1017 |
+
|
| 1018 |
+
print(f"\nTop-10 Hardest Patients:")
|
| 1019 |
+
fg_dice_cols = [f"{class_names[c].lower().replace(' ', '_')}_mean_dice"
|
| 1020 |
+
for c in range(1, num_classes)]
|
| 1021 |
+
pat_cols = ['difficulty_rank', 'patient_id', 'n_slices', 'composite_dice'] + \
|
| 1022 |
+
[c for c in fg_dice_cols if c in patient_df.columns]
|
| 1023 |
+
print(patient_df[pat_cols].head(10).to_string(index=False))
|
| 1024 |
+
|
| 1025 |
+
print("\n" + "=" * 70)
|
| 1026 |
+
print(f"All error analysis outputs → {error_dir}")
|
| 1027 |
+
print("=" * 70 + "\n")
|
| 1028 |
+
|
| 1029 |
+
return {
|
| 1030 |
+
'slice_df': slice_df,
|
| 1031 |
+
'patient_df': patient_df,
|
| 1032 |
+
'error_dir': error_dir
|
| 1033 |
+
}
|
models/for_WMH_Vent/model_training_scripts/p4_folds_results_aggregator.py
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 - All U-Net models with Adaptive Loss (WCE + UFL)
|
| 3 |
+
|
| 4 |
+
WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
|
| 5 |
+
Three-class segmentation: Background vs Ventricles vs Abnormal WMH
|
| 6 |
+
Professional results saving and visualization for publication
|
| 7 |
+
|
| 8 |
+
This relates to our article:
|
| 9 |
+
"Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
|
| 10 |
+
A Large-Scale Multiple Sclerosis Study"
|
| 11 |
+
|
| 12 |
+
Features:
|
| 13 |
+
- Aggregatation of all inferenced results
|
| 14 |
+
- Includes lesion-level (connected-component) metrics: sensitivity, precision,
|
| 15 |
+
F1, TP/FP/FN lesion counts (added to address reviewer R1C7)
|
| 16 |
+
|
| 17 |
+
Authors:
|
| 18 |
+
"Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
|
| 19 |
+
|
| 20 |
+
Developer:
|
| 21 |
+
"Mahdi Bashiri Bawil"
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import json
|
| 26 |
+
import pandas as pd
|
| 27 |
+
import numpy as np
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
import warnings
|
| 30 |
+
warnings.filterwarnings('ignore')
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ResultsAggregator:
|
| 34 |
+
"""
|
| 35 |
+
Aggregates segmentation results across multiple variants and folds.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, base_dir='./'):
|
| 39 |
+
"""
|
| 40 |
+
Initialize the aggregator.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
base_dir: Base directory containing all results folders
|
| 44 |
+
"""
|
| 45 |
+
self.base_dir = Path(base_dir)
|
| 46 |
+
self.variants = {
|
| 47 |
+
1: "unet",
|
| 48 |
+
2: "attnunet",
|
| 49 |
+
3: "dlv3unet",
|
| 50 |
+
4: "transunet"
|
| 51 |
+
}
|
| 52 |
+
self.class_names = ["Background", "Ventricles", "Abnormal_WMH"]
|
| 53 |
+
self.num_variants = 4
|
| 54 |
+
self.num_folds = 4
|
| 55 |
+
|
| 56 |
+
def find_results_folders(self):
|
| 57 |
+
"""Find all results folders matching the naming pattern."""
|
| 58 |
+
results_folders = []
|
| 59 |
+
for variant in range(self.num_variants):
|
| 60 |
+
for fold in range(self.num_folds):
|
| 61 |
+
folder_pattern = f"results_fold_{fold}_var_{variant+1}_zscore2"
|
| 62 |
+
folder_path = self.base_dir / folder_pattern
|
| 63 |
+
if folder_path.exists():
|
| 64 |
+
results_folders.append({
|
| 65 |
+
'variant': variant+1,
|
| 66 |
+
'fold': fold,
|
| 67 |
+
'path': folder_path
|
| 68 |
+
})
|
| 69 |
+
return results_folders
|
| 70 |
+
|
| 71 |
+
def load_test_metrics(self, results_folder):
|
| 72 |
+
"""Load test metrics from JSON file."""
|
| 73 |
+
metrics_path = results_folder['path'] / 'inference_all_test' / 'standard_3class' / 'metrics' / 'test_metrics_complete.json'
|
| 74 |
+
|
| 75 |
+
if not metrics_path.exists():
|
| 76 |
+
print(f"Warning: Metrics file not found at {metrics_path}")
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
with open(metrics_path, 'r') as f:
|
| 80 |
+
data = json.load(f)
|
| 81 |
+
|
| 82 |
+
return data
|
| 83 |
+
|
| 84 |
+
def load_training_summary(self, results_folder):
|
| 85 |
+
"""Load training summary from JSON file (new format)."""
|
| 86 |
+
summary_path = results_folder['path'] / 'models' / 'standard_3class' / f"fold_{results_folder['fold']}" / 'training_summary.json'
|
| 87 |
+
|
| 88 |
+
if not summary_path.exists():
|
| 89 |
+
# Fallback to history.json if training_summary doesn't exist
|
| 90 |
+
return self.load_training_history(results_folder)
|
| 91 |
+
|
| 92 |
+
with open(summary_path, 'r') as f:
|
| 93 |
+
data = json.load(f)
|
| 94 |
+
|
| 95 |
+
return data
|
| 96 |
+
|
| 97 |
+
def load_training_history(self, results_folder):
|
| 98 |
+
"""Load training history from JSON file (legacy support)."""
|
| 99 |
+
history_path = results_folder['path'] / 'models' / 'standard_3class' / f"fold_{results_folder['fold']}" / 'history.json'
|
| 100 |
+
|
| 101 |
+
if not history_path.exists():
|
| 102 |
+
print(f"Warning: History file not found at {history_path}")
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
with open(history_path, 'r') as f:
|
| 106 |
+
data = json.load(f)
|
| 107 |
+
|
| 108 |
+
return data
|
| 109 |
+
|
| 110 |
+
def load_best_epoch_analysis(self, results_folder):
|
| 111 |
+
"""Load best epoch analysis from JSON file (new format)."""
|
| 112 |
+
analysis_path = results_folder['path'] / 'models' / 'standard_3class' / f"fold_{results_folder['fold']}" / 'best_epoch_analysis.json'
|
| 113 |
+
|
| 114 |
+
if not analysis_path.exists():
|
| 115 |
+
return None
|
| 116 |
+
|
| 117 |
+
with open(analysis_path, 'r') as f:
|
| 118 |
+
data = json.load(f)
|
| 119 |
+
|
| 120 |
+
return data
|
| 121 |
+
|
| 122 |
+
def extract_test_metrics_row(self, results_folder, metrics_data):
|
| 123 |
+
"""
|
| 124 |
+
Extract a row of test metrics for the summary dataframe.
|
| 125 |
+
Includes both voxel-level and lesion-level metrics.
|
| 126 |
+
"""
|
| 127 |
+
if metrics_data is None:
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
row = {
|
| 131 |
+
'Variant': results_folder['variant'],
|
| 132 |
+
'Variant_Name': self.variants[results_folder['variant']],
|
| 133 |
+
'Fold': results_folder['fold'],
|
| 134 |
+
'Test_Samples': metrics_data['config']['test_samples']
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
# ── Voxel-level metrics (unchanged) ─────────────────────────────────
|
| 138 |
+
for metric_name in ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95']:
|
| 139 |
+
metric_data = metrics_data['metrics'][metric_name]
|
| 140 |
+
|
| 141 |
+
for class_idx in range(3):
|
| 142 |
+
if class_idx != 0:
|
| 143 |
+
row[f'{metric_name.upper()}_class_{class_idx}'] = metric_data.get(f'class_{class_idx}')
|
| 144 |
+
|
| 145 |
+
row[f'{metric_name.upper()}_mean'] = metric_data.get('mean')
|
| 146 |
+
|
| 147 |
+
# ── Lesion-level metrics (new — R1C7) ────────────────────────────────
|
| 148 |
+
lesion_data = metrics_data['metrics'].get('lesion', None)
|
| 149 |
+
if lesion_data is not None:
|
| 150 |
+
for class_idx in range(2): # foreground classes only
|
| 151 |
+
key = f'class_{class_idx}'
|
| 152 |
+
cls = lesion_data.get(key, {})
|
| 153 |
+
|
| 154 |
+
# Scalar rates (averaged across patients in inference script)
|
| 155 |
+
for sk in ['lesion_sensitivity', 'lesion_precision', 'lesion_f1']:
|
| 156 |
+
col = f'LESION_{sk.upper()}_class_{class_idx}'
|
| 157 |
+
row[col] = cls.get(sk)
|
| 158 |
+
|
| 159 |
+
# Integer counts (summed across patients in inference script)
|
| 160 |
+
for ck in ['n_gt_lesions', 'n_pred_lesions', 'tp_lesions', 'fn_lesions', 'fp_lesions']:
|
| 161 |
+
col = f'LESION_{ck.upper()}_class_{class_idx}'
|
| 162 |
+
row[col] = cls.get(ck)
|
| 163 |
+
|
| 164 |
+
# Cross-class summary keys produced by aggregate_patient_metrics()
|
| 165 |
+
for sk in ['lesion_sensitivity', 'lesion_precision', 'lesion_f1']:
|
| 166 |
+
row[f'LESION_{sk.upper()}_mean'] = lesion_data.get(f'mean_{sk}')
|
| 167 |
+
for ck in ['n_gt_lesions', 'n_pred_lesions', 'tp_lesions', 'fn_lesions', 'fp_lesions']:
|
| 168 |
+
row[f'LESION_{ck.upper()}_total'] = lesion_data.get(f'total_{ck}')
|
| 169 |
+
|
| 170 |
+
return row
|
| 171 |
+
|
| 172 |
+
def extract_training_info_row(self, results_folder, training_data, best_epoch_analysis):
|
| 173 |
+
"""Extract training information including best epoch details."""
|
| 174 |
+
if training_data is None:
|
| 175 |
+
return None
|
| 176 |
+
|
| 177 |
+
row = {
|
| 178 |
+
'Variant': results_folder['variant'],
|
| 179 |
+
'Variant_Name': self.variants[results_folder['variant']],
|
| 180 |
+
'Fold': results_folder['fold']
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
# Try to extract from training_summary.json first
|
| 184 |
+
if isinstance(training_data, dict) and 'best_epoch_selection' in training_data:
|
| 185 |
+
row['Best_Epoch'] = training_data['best_epoch_selection']['overall_best_epoch']
|
| 186 |
+
row['Composite_Score'] = training_data['best_epoch_selection']['composite_score']
|
| 187 |
+
row['Total_Epochs'] = training_data['training_config']['total_epochs']
|
| 188 |
+
# Handle valid_epochs (only for Pix2Pix variants with beta scheduling)
|
| 189 |
+
if 'valid_epochs' in training_data['best_epoch_selection']:
|
| 190 |
+
row['First_Valid_Epoch'] = training_data['best_epoch_selection']['valid_epochs']['first_valid_epoch']
|
| 191 |
+
row['Total_Valid_Epochs'] = training_data['best_epoch_selection']['valid_epochs']['total_valid_epochs']
|
| 192 |
+
else:
|
| 193 |
+
row['First_Valid_Epoch'] = 1
|
| 194 |
+
row['Total_Valid_Epochs'] = training_data['training_config']['total_epochs']
|
| 195 |
+
|
| 196 |
+
# Best epoch metrics
|
| 197 |
+
best_metrics = training_data['best_epoch_metrics']
|
| 198 |
+
row['Best_Epoch_Val_Loss'] = best_metrics['val_loss']
|
| 199 |
+
row['Best_Epoch_Dice_Ventricles'] = best_metrics['dice']['class_1']
|
| 200 |
+
row['Best_Epoch_Dice_Abnormal_WMH'] = best_metrics['dice'].get('class_2', None)
|
| 201 |
+
row['Best_Epoch_Dice_Mean'] = best_metrics['dice']['mean']
|
| 202 |
+
|
| 203 |
+
# Priority metrics
|
| 204 |
+
row['Best_Abnormal_Epoch'] = training_data['priority_metrics']['abnormal_wmh']['best_epoch']
|
| 205 |
+
row['Best_Abnormal_Dice'] = training_data['priority_metrics']['abnormal_wmh']['best_dice']
|
| 206 |
+
row['Best_Ventricles_Epoch'] = training_data['priority_metrics']['ventricles']['best_epoch']
|
| 207 |
+
row['Best_Ventricles_Dice'] = training_data['priority_metrics']['ventricles']['best_dice']
|
| 208 |
+
|
| 209 |
+
# Fallback to best_epoch_analysis.json
|
| 210 |
+
elif best_epoch_analysis is not None:
|
| 211 |
+
row['Best_Epoch'] = best_epoch_analysis['best_overall_epoch']
|
| 212 |
+
row['Composite_Score'] = best_epoch_analysis['composite_score']
|
| 213 |
+
row['Total_Epochs'] = best_epoch_analysis['total_epochs']
|
| 214 |
+
row['First_Valid_Epoch'] = best_epoch_analysis['first_valid_epoch']
|
| 215 |
+
row['Total_Valid_Epochs'] = best_epoch_analysis['total_valid_epochs']
|
| 216 |
+
|
| 217 |
+
# Best epoch metrics
|
| 218 |
+
best_metrics = best_epoch_analysis['best_epoch_metrics']
|
| 219 |
+
row['Best_Epoch_Val_Loss'] = best_metrics['val_loss']
|
| 220 |
+
row['Best_Epoch_Dice_Ventricles'] = best_metrics['dice']['class_1']
|
| 221 |
+
row['Best_Epoch_Dice_Abnormal_WMH'] = best_metrics['dice'].get('class_2', None)
|
| 222 |
+
row['Best_Epoch_Dice_Mean'] = best_metrics['dice']['mean']
|
| 223 |
+
|
| 224 |
+
# Priority metrics
|
| 225 |
+
row['Best_Abnormal_Epoch'] = best_epoch_analysis['best_abnormal_epoch']
|
| 226 |
+
row['Best_Abnormal_Dice'] = best_epoch_analysis['best_abnormal_dice']
|
| 227 |
+
row['Best_Ventricles_Epoch'] = best_epoch_analysis['best_ventricles_epoch']
|
| 228 |
+
row['Best_Ventricles_Dice'] = best_epoch_analysis['best_ventricles_dice']
|
| 229 |
+
|
| 230 |
+
# Legacy fallback to history.json
|
| 231 |
+
elif isinstance(training_data, dict) and 'val_metrics' in training_data:
|
| 232 |
+
if 'best_epoch_analysis' in training_data:
|
| 233 |
+
analysis = training_data['best_epoch_analysis']
|
| 234 |
+
row['Best_Epoch'] = analysis['best_overall_epoch']
|
| 235 |
+
row['Composite_Score'] = analysis.get('composite_score', None)
|
| 236 |
+
else:
|
| 237 |
+
# Find best validation dice
|
| 238 |
+
val_dice_list = [m['dice']['mean'] for m in training_data['val_metrics']]
|
| 239 |
+
row['Best_Epoch'] = val_dice_list.index(max(val_dice_list)) + 1
|
| 240 |
+
row['Composite_Score'] = max(val_dice_list)
|
| 241 |
+
|
| 242 |
+
row['Total_Epochs'] = len(training_data['val_metrics'])
|
| 243 |
+
|
| 244 |
+
return row
|
| 245 |
+
|
| 246 |
+
def create_test_metrics_summary(self):
|
| 247 |
+
"""Create a comprehensive summary of test metrics."""
|
| 248 |
+
results_folders = self.find_results_folders()
|
| 249 |
+
|
| 250 |
+
if not results_folders:
|
| 251 |
+
print("No results folders found!")
|
| 252 |
+
return None
|
| 253 |
+
|
| 254 |
+
rows = []
|
| 255 |
+
for folder in results_folders:
|
| 256 |
+
metrics_data = self.load_test_metrics(folder)
|
| 257 |
+
row = self.extract_test_metrics_row(folder, metrics_data)
|
| 258 |
+
if row is not None:
|
| 259 |
+
rows.append(row)
|
| 260 |
+
|
| 261 |
+
df = pd.DataFrame(rows)
|
| 262 |
+
df = df.sort_values(['Variant', 'Fold']).reset_index(drop=True)
|
| 263 |
+
|
| 264 |
+
return df
|
| 265 |
+
|
| 266 |
+
def create_training_summary(self):
|
| 267 |
+
"""Create a comprehensive summary of training information."""
|
| 268 |
+
results_folders = self.find_results_folders()
|
| 269 |
+
|
| 270 |
+
if not results_folders:
|
| 271 |
+
print("No results folders found!")
|
| 272 |
+
return None
|
| 273 |
+
|
| 274 |
+
rows = []
|
| 275 |
+
for folder in results_folders:
|
| 276 |
+
training_data = self.load_training_summary(folder)
|
| 277 |
+
best_epoch_analysis = self.load_best_epoch_analysis(folder)
|
| 278 |
+
row = self.extract_training_info_row(folder, training_data, best_epoch_analysis)
|
| 279 |
+
if row is not None:
|
| 280 |
+
rows.append(row)
|
| 281 |
+
|
| 282 |
+
df = pd.DataFrame(rows)
|
| 283 |
+
df = df.sort_values(['Variant', 'Fold']).reset_index(drop=True)
|
| 284 |
+
|
| 285 |
+
return df
|
| 286 |
+
|
| 287 |
+
def create_per_class_summary(self, test_metrics_df):
|
| 288 |
+
"""
|
| 289 |
+
Create per-class summary statistics across folds for each variant.
|
| 290 |
+
Includes both voxel-level and lesion-level metrics.
|
| 291 |
+
"""
|
| 292 |
+
summaries = []
|
| 293 |
+
|
| 294 |
+
for variant in range(self.num_variants +1):
|
| 295 |
+
variant_data = test_metrics_df[test_metrics_df['Variant'] == variant]
|
| 296 |
+
|
| 297 |
+
if len(variant_data) == 0:
|
| 298 |
+
continue
|
| 299 |
+
|
| 300 |
+
for class_idx in range(3):
|
| 301 |
+
if class_idx == 0:
|
| 302 |
+
continue
|
| 303 |
+
|
| 304 |
+
class_summary = {
|
| 305 |
+
'Variant': variant,
|
| 306 |
+
'Variant_Name': self.variants[variant],
|
| 307 |
+
'Class': class_idx,
|
| 308 |
+
'Class_Name': self.class_names[class_idx]
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
# Voxel-level metrics
|
| 312 |
+
for metric in ['DICE', 'PRECISION', 'RECALL', 'IOU', 'SPECIFICITY', 'HD95']:
|
| 313 |
+
col_name = f'{metric}_class_{class_idx}'
|
| 314 |
+
if col_name in variant_data.columns:
|
| 315 |
+
values = variant_data[col_name].dropna().values
|
| 316 |
+
class_summary[f'{metric}_mean'] = np.mean(values)
|
| 317 |
+
class_summary[f'{metric}_std'] = np.std(values)
|
| 318 |
+
class_summary[f'{metric}_min'] = np.min(values)
|
| 319 |
+
class_summary[f'{metric}_max'] = np.max(values)
|
| 320 |
+
|
| 321 |
+
# Lesion-level scalar metrics (mean ± std across folds)
|
| 322 |
+
for sk in ['LESION_SENSITIVITY', 'LESION_PRECISION', 'LESION_F1']:
|
| 323 |
+
col_name = f'LESION_{sk}_class_{class_idx}'
|
| 324 |
+
if col_name in variant_data.columns:
|
| 325 |
+
values = variant_data[col_name].dropna().values
|
| 326 |
+
class_summary[f'{sk}_mean'] = np.mean(values) if len(values) else np.nan
|
| 327 |
+
class_summary[f'{sk}_std'] = np.std(values) if len(values) else np.nan
|
| 328 |
+
|
| 329 |
+
# Lesion-level count metrics (sum across folds — total pool)
|
| 330 |
+
for ck in ['N_GT_LESIONS', 'N_PRED_LESIONS', 'TP_LESIONS', 'FN_LESIONS', 'FP_LESIONS']:
|
| 331 |
+
col_name = f'LESION_{ck}_class_{class_idx}'
|
| 332 |
+
if col_name in variant_data.columns:
|
| 333 |
+
values = variant_data[col_name].dropna().values
|
| 334 |
+
class_summary[f'LESION_{ck}_total'] = int(np.sum(values)) if len(values) else 0
|
| 335 |
+
|
| 336 |
+
summaries.append(class_summary)
|
| 337 |
+
|
| 338 |
+
df = pd.DataFrame(summaries)
|
| 339 |
+
return df
|
| 340 |
+
|
| 341 |
+
def create_variant_comparison(self, test_metrics_df):
|
| 342 |
+
"""
|
| 343 |
+
Create a variant comparison table with mean ± std across folds.
|
| 344 |
+
Includes both voxel-level and lesion-level metrics.
|
| 345 |
+
"""
|
| 346 |
+
comparisons = []
|
| 347 |
+
|
| 348 |
+
for variant in range(self.num_variants + 1):
|
| 349 |
+
variant_data = test_metrics_df[test_metrics_df['Variant'] == variant]
|
| 350 |
+
|
| 351 |
+
if len(variant_data) == 0:
|
| 352 |
+
continue
|
| 353 |
+
|
| 354 |
+
comparison = {
|
| 355 |
+
'Variant': variant,
|
| 356 |
+
'Variant_Name': self.variants[variant],
|
| 357 |
+
'N_Folds': len(variant_data)
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
# ── Voxel-level metrics ──────────────────────────────────────────
|
| 361 |
+
for metric in ['DICE', 'PRECISION', 'RECALL', 'IOU', 'SPECIFICITY', 'HD95']:
|
| 362 |
+
# Overall mean across classes
|
| 363 |
+
col_name = f'{metric}_mean'
|
| 364 |
+
if col_name in variant_data.columns:
|
| 365 |
+
values = variant_data[col_name].dropna().values
|
| 366 |
+
comparison[f'{metric}_Mean'] = np.mean(values)
|
| 367 |
+
comparison[f'{metric}_Std'] = np.std(values)
|
| 368 |
+
|
| 369 |
+
# Per-class (Ventricles=1, Abnormal_WMH=2)
|
| 370 |
+
for class_idx in [1, 2]:
|
| 371 |
+
col_name = f'{metric}_class_{class_idx}'
|
| 372 |
+
if col_name in variant_data.columns:
|
| 373 |
+
values = variant_data[col_name].dropna().values
|
| 374 |
+
comparison[f'{metric}_Class{class_idx}_Mean'] = np.mean(values)
|
| 375 |
+
comparison[f'{metric}_Class{class_idx}_Std'] = np.std(values)
|
| 376 |
+
|
| 377 |
+
# ── Lesion-level scalar metrics (mean ± std across folds) ────────
|
| 378 |
+
for sk_suffix in ['LESION_SENSITIVITY', 'LESION_PRECISION', 'LESION_F1']:
|
| 379 |
+
# Cross-class mean
|
| 380 |
+
col_name = f'LESION_{sk_suffix}_mean'
|
| 381 |
+
if col_name in variant_data.columns:
|
| 382 |
+
values = variant_data[col_name].dropna().values
|
| 383 |
+
comparison[f'{sk_suffix}_Mean'] = np.mean(values) if len(values) else np.nan
|
| 384 |
+
comparison[f'{sk_suffix}_Std'] = np.std(values) if len(values) else np.nan
|
| 385 |
+
|
| 386 |
+
# Per-class
|
| 387 |
+
for class_idx in [2]:
|
| 388 |
+
col_name = f'LESION_{sk_suffix}_class_{class_idx}'
|
| 389 |
+
if col_name in variant_data.columns:
|
| 390 |
+
values = variant_data[col_name].dropna().values
|
| 391 |
+
comparison[f'{sk_suffix}_Class{class_idx}_Mean'] = np.mean(values) if len(values) else np.nan
|
| 392 |
+
comparison[f'{sk_suffix}_Class{class_idx}_Std'] = np.std(values) if len(values) else np.nan
|
| 393 |
+
|
| 394 |
+
# ── Lesion-level count metrics (sum across folds) ────────────────
|
| 395 |
+
for ck in ['N_GT_LESIONS', 'N_PRED_LESIONS', 'TP_LESIONS', 'FN_LESIONS', 'FP_LESIONS']:
|
| 396 |
+
# Total across all classes
|
| 397 |
+
col_name = f'LESION_{ck}_total'
|
| 398 |
+
if col_name in variant_data.columns:
|
| 399 |
+
values = variant_data[col_name].dropna().values
|
| 400 |
+
comparison[f'LESION_{ck}_Total'] = int(np.sum(values)) if len(values) else 0
|
| 401 |
+
|
| 402 |
+
# Per-class totals
|
| 403 |
+
for class_idx in [2]:
|
| 404 |
+
col_name = f'LESION_{ck}_class_{class_idx}'
|
| 405 |
+
if col_name in variant_data.columns:
|
| 406 |
+
values = variant_data[col_name].dropna().values
|
| 407 |
+
comparison[f'LESION_{ck}_Class{class_idx}_Total'] = int(np.sum(values)) if len(values) else 0
|
| 408 |
+
|
| 409 |
+
comparisons.append(comparison)
|
| 410 |
+
|
| 411 |
+
df = pd.DataFrame(comparisons)
|
| 412 |
+
return df
|
| 413 |
+
|
| 414 |
+
def create_training_comparison(self, training_df):
|
| 415 |
+
"""Create training comparison showing convergence patterns."""
|
| 416 |
+
if training_df is None:
|
| 417 |
+
return None
|
| 418 |
+
|
| 419 |
+
comparisons = []
|
| 420 |
+
|
| 421 |
+
for variant in range(self.num_variants + 1):
|
| 422 |
+
variant_data = training_df[training_df['Variant'] == variant]
|
| 423 |
+
|
| 424 |
+
if len(variant_data) == 0:
|
| 425 |
+
continue
|
| 426 |
+
|
| 427 |
+
comparison = {
|
| 428 |
+
'Variant': variant,
|
| 429 |
+
'Variant_Name': self.variants[variant],
|
| 430 |
+
'N_Folds': len(variant_data)
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
# Best epoch statistics
|
| 434 |
+
if 'Best_Epoch' in variant_data.columns:
|
| 435 |
+
comparison['Best_Epoch_Mean'] = np.mean(variant_data['Best_Epoch'].values)
|
| 436 |
+
comparison['Best_Epoch_Std'] = np.std(variant_data['Best_Epoch'].values)
|
| 437 |
+
comparison['Best_Epoch_Min'] = np.min(variant_data['Best_Epoch'].values)
|
| 438 |
+
comparison['Best_Epoch_Max'] = np.max(variant_data['Best_Epoch'].values)
|
| 439 |
+
|
| 440 |
+
# Composite score statistics
|
| 441 |
+
if 'Composite_Score' in variant_data.columns:
|
| 442 |
+
comparison['Composite_Score_Mean'] = np.mean(variant_data['Composite_Score'].dropna().values)
|
| 443 |
+
comparison['Composite_Score_Std'] = np.std(variant_data['Composite_Score'].dropna().values)
|
| 444 |
+
|
| 445 |
+
# Validation metrics at best epoch
|
| 446 |
+
for metric_col in ['Best_Epoch_Val_Loss', 'Best_Epoch_Dice_Mean',
|
| 447 |
+
'Best_Epoch_Dice_Ventricles', 'Best_Epoch_Dice_Abnormal_WMH']:
|
| 448 |
+
if metric_col in variant_data.columns:
|
| 449 |
+
values = variant_data[metric_col].dropna().values
|
| 450 |
+
if len(values) > 0:
|
| 451 |
+
comparison[f'{metric_col}_Mean'] = np.mean(values)
|
| 452 |
+
comparison[f'{metric_col}_Std'] = np.std(values)
|
| 453 |
+
|
| 454 |
+
comparisons.append(comparison)
|
| 455 |
+
|
| 456 |
+
df = pd.DataFrame(comparisons)
|
| 457 |
+
return df
|
| 458 |
+
|
| 459 |
+
def generate_all_summaries(self, output_dir='./folds_results'):
|
| 460 |
+
"""Generate all summary CSV files."""
|
| 461 |
+
output_path = Path(output_dir)
|
| 462 |
+
output_path.mkdir(exist_ok=True)
|
| 463 |
+
|
| 464 |
+
print("=" * 80)
|
| 465 |
+
print("RESULTS AGGREGATION STARTED")
|
| 466 |
+
print("=" * 80)
|
| 467 |
+
|
| 468 |
+
# 1. Test Metrics Summary (all variants, all folds)
|
| 469 |
+
print("\n1. Generating test metrics summary...")
|
| 470 |
+
test_metrics_df = self.create_test_metrics_summary()
|
| 471 |
+
if test_metrics_df is not None:
|
| 472 |
+
output_file = output_path / 'test_metrics_all_variants_folds.csv'
|
| 473 |
+
test_metrics_df.to_csv(output_file, index=False)
|
| 474 |
+
print(f" ✓ Saved: {output_file}")
|
| 475 |
+
print(f" - Shape: {test_metrics_df.shape}")
|
| 476 |
+
|
| 477 |
+
# 2. Training Summary
|
| 478 |
+
print("\n2. Generating training summary...")
|
| 479 |
+
training_df = self.create_training_summary()
|
| 480 |
+
if training_df is not None:
|
| 481 |
+
output_file = output_path / 'training_info_all_variants_folds.csv'
|
| 482 |
+
training_df.to_csv(output_file, index=False)
|
| 483 |
+
print(f" ✓ Saved: {output_file}")
|
| 484 |
+
print(f" - Shape: {training_df.shape}")
|
| 485 |
+
|
| 486 |
+
# 3. Per-Class Summary
|
| 487 |
+
print("\n3. Generating per-class summary...")
|
| 488 |
+
per_class_df = None
|
| 489 |
+
if test_metrics_df is not None:
|
| 490 |
+
per_class_df = self.create_per_class_summary(test_metrics_df)
|
| 491 |
+
output_file = output_path / 'per_class_summary.csv'
|
| 492 |
+
per_class_df.to_csv(output_file, index=False)
|
| 493 |
+
print(f" ✓ Saved: {output_file}")
|
| 494 |
+
print(f" - Shape: {per_class_df.shape}")
|
| 495 |
+
|
| 496 |
+
# 4. Variant Comparison (Test Metrics)
|
| 497 |
+
print("\n4. Generating variant comparison (test metrics)...")
|
| 498 |
+
variant_comparison_df = None
|
| 499 |
+
if test_metrics_df is not None:
|
| 500 |
+
variant_comparison_df = self.create_variant_comparison(test_metrics_df)
|
| 501 |
+
output_file = output_path / 'variant_comparison_test.csv'
|
| 502 |
+
variant_comparison_df.to_csv(output_file, index=False)
|
| 503 |
+
print(f" ✓ Saved: {output_file}")
|
| 504 |
+
print(f" - Shape: {variant_comparison_df.shape}")
|
| 505 |
+
|
| 506 |
+
# 5. Variant Comparison (Training)
|
| 507 |
+
print("\n5. Generating variant comparison (training)...")
|
| 508 |
+
training_comparison_df = None
|
| 509 |
+
if training_df is not None:
|
| 510 |
+
training_comparison_df = self.create_training_comparison(training_df)
|
| 511 |
+
if training_comparison_df is not None:
|
| 512 |
+
output_file = output_path / 'variant_comparison_training.csv'
|
| 513 |
+
training_comparison_df.to_csv(output_file, index=False)
|
| 514 |
+
print(f" ✓ Saved: {output_file}")
|
| 515 |
+
print(f" - Shape: {training_comparison_df.shape}")
|
| 516 |
+
|
| 517 |
+
print("\n" + "=" * 80)
|
| 518 |
+
print("AGGREGATION COMPLETE")
|
| 519 |
+
print("=" * 80)
|
| 520 |
+
|
| 521 |
+
return {
|
| 522 |
+
'test_metrics': test_metrics_df,
|
| 523 |
+
'training_info': training_df,
|
| 524 |
+
'per_class': per_class_df,
|
| 525 |
+
'variant_comparison_test': variant_comparison_df,
|
| 526 |
+
'variant_comparison_training': training_comparison_df
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
def print_summary_statistics(self, dfs):
|
| 530 |
+
"""Print summary statistics to console."""
|
| 531 |
+
print("\n" + "=" * 80)
|
| 532 |
+
print("SUMMARY STATISTICS")
|
| 533 |
+
print("=" * 80)
|
| 534 |
+
|
| 535 |
+
if dfs['variant_comparison_test'] is not None:
|
| 536 |
+
|
| 537 |
+
# ── Voxel-level Dice ─────────────────────────────────────────────
|
| 538 |
+
print("\n📊 TEST DICE SCORES (Mean ± Std) across folds:")
|
| 539 |
+
print("-" * 80)
|
| 540 |
+
for _, row in dfs['variant_comparison_test'].iterrows():
|
| 541 |
+
print(f"\nVariant {row['Variant']}: {row['Variant_Name']}")
|
| 542 |
+
print(f" Overall: {row['DICE_Mean']:.4f} ± {row['DICE_Std']:.4f}")
|
| 543 |
+
print(f" Ventricles: {row['DICE_Class1_Mean']:.4f} ± {row['DICE_Class1_Std']:.4f}")
|
| 544 |
+
print(f" Abnormal WMH: {row['DICE_Class2_Mean']:.4f} ± {row['DICE_Class2_Std']:.4f}")
|
| 545 |
+
|
| 546 |
+
# ── Lesion-level metrics ─────────────────────────────────────────
|
| 547 |
+
lesion_cols_present = any(
|
| 548 |
+
col.startswith('LESION_') for col in dfs['variant_comparison_test'].columns
|
| 549 |
+
)
|
| 550 |
+
if lesion_cols_present:
|
| 551 |
+
print("\n\n🔬 LESION-LEVEL METRICS (Mean ± Std) across folds:")
|
| 552 |
+
print("-" * 80)
|
| 553 |
+
for _, row in dfs['variant_comparison_test'].iterrows():
|
| 554 |
+
print(f"\nVariant {row['Variant']}: {row['Variant_Name']}")
|
| 555 |
+
|
| 556 |
+
# Per-class
|
| 557 |
+
for class_idx, class_name in [(2, 'Abnormal WMH')]:
|
| 558 |
+
sens_col = f'LESION_LESION_SENSITIVITY_Class{class_idx}_Mean'
|
| 559 |
+
prec_col = f'LESION_LESION_PRECISION_Class{class_idx}_Mean'
|
| 560 |
+
f1_col = f'LESION_LESION_F1_Class{class_idx}_Mean'
|
| 561 |
+
tp_col = f'LESION_TP_LESIONS_Class{class_idx}_Total'
|
| 562 |
+
fp_col = f'LESION_FP_LESIONS_Class{class_idx}_Total'
|
| 563 |
+
fn_col = f'LESION_FN_LESIONS_Class{class_idx}_Total'
|
| 564 |
+
gt_col = f'LESION_N_GT_LESIONS_Class{class_idx}_Total'
|
| 565 |
+
|
| 566 |
+
print(f" [{class_name}]")
|
| 567 |
+
if sens_col in row:
|
| 568 |
+
s_m = f"{row[sens_col]:.4f}" if pd.notna(row.get(sens_col)) else 'N/A'
|
| 569 |
+
s_s = f"{row.get(f'LESION_LESION_SENSITIVITY_Class{class_idx}_Std', float('nan')):.4f}"
|
| 570 |
+
p_m = f"{row[prec_col]:.4f}" if pd.notna(row.get(prec_col)) else 'N/A'
|
| 571 |
+
p_s = f"{row.get(f'LESION_LESION_PRECISION_Class{class_idx}_Std', float('nan')):.4f}"
|
| 572 |
+
f_m = f"{row[f1_col]:.4f}" if pd.notna(row.get(f1_col)) else 'N/A'
|
| 573 |
+
f_s = f"{row.get(f'LESION_LESION_F1_Class{class_idx}_Std', float('nan')):.4f}"
|
| 574 |
+
print(f" Sensitivity : {s_m} ± {s_s}")
|
| 575 |
+
print(f" Precision : {p_m} ± {p_s}")
|
| 576 |
+
print(f" F1 : {f_m} ± {f_s}")
|
| 577 |
+
if gt_col in row:
|
| 578 |
+
print(f" GT Lesions : {int(row.get(gt_col, 0))} "
|
| 579 |
+
f"TP: {int(row.get(tp_col, 0))} "
|
| 580 |
+
f"FP: {int(row.get(fp_col, 0))} "
|
| 581 |
+
f"FN: {int(row.get(fn_col, 0))}")
|
| 582 |
+
|
| 583 |
+
if dfs['variant_comparison_training'] is not None:
|
| 584 |
+
print("\n\n🏆 TRAINING CONVERGENCE:")
|
| 585 |
+
print("-" * 80)
|
| 586 |
+
for _, row in dfs['variant_comparison_training'].iterrows():
|
| 587 |
+
print(f"\nVariant {row['Variant']}: {row['Variant_Name']}")
|
| 588 |
+
if 'Best_Epoch_Mean' in row:
|
| 589 |
+
print(f" Best Epoch: {row['Best_Epoch_Mean']:.1f} ± {row['Best_Epoch_Std']:.1f}")
|
| 590 |
+
if 'Best_Epoch_Dice_Abnormal_WMH_Mean' in row:
|
| 591 |
+
print(f" Val Abnormal: {row['Best_Epoch_Dice_Abnormal_WMH_Mean']:.4f} ± {row['Best_Epoch_Dice_Abnormal_WMH_Std']:.4f}")
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
# Main execution
|
| 595 |
+
if __name__ == "__main__":
|
| 596 |
+
# Initialize aggregator
|
| 597 |
+
aggregator = ResultsAggregator(base_dir='./')
|
| 598 |
+
|
| 599 |
+
# Generate all summaries
|
| 600 |
+
dfs = aggregator.generate_all_summaries(output_dir='./folds_results_zscore2_all')
|
| 601 |
+
|
| 602 |
+
# Print summary statistics
|
| 603 |
+
aggregator.print_summary_statistics(dfs)
|
| 604 |
+
|
| 605 |
+
print("\n✓ All CSV files have been generated in './folds_results_zscore2_all' directory")
|
| 606 |
+
print("\nGenerated files:")
|
| 607 |
+
print(" 1. test_metrics_all_variants_folds.csv - Complete test metrics (voxel + lesion level)")
|
| 608 |
+
print(" 2. training_info_all_variants_folds.csv - Training convergence info")
|
| 609 |
+
print(" 3. per_class_summary.csv - Per-class statistics (voxel + lesion level)")
|
| 610 |
+
print(" 4. variant_comparison_test.csv - Test metrics comparison (voxel + lesion level)")
|
| 611 |
+
print(" 5. variant_comparison_training.csv - Training comparison")
|
models/for_WMH_Vent/model_training_scripts/p4_inference.py
ADDED
|
@@ -0,0 +1,1146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 Article - Inference Script for ventricles and WMH segmentation task
|
| 3 |
+
|
| 4 |
+
Developer:
|
| 5 |
+
Mahdi Bashiri Bawil
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
import os
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
import numpy as np
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import json
|
| 16 |
+
import nibabel as nib
|
| 17 |
+
import seaborn as sns
|
| 18 |
+
from sklearn.metrics import confusion_matrix, cohen_kappa_score, classification_report
|
| 19 |
+
|
| 20 |
+
from scipy.spatial.distance import directed_hausdorff
|
| 21 |
+
from scipy.ndimage import distance_transform_edt
|
| 22 |
+
from scipy.spatial.distance import cdist
|
| 23 |
+
from scipy.ndimage import binary_erosion
|
| 24 |
+
from scipy.ndimage import label as nd_label
|
| 25 |
+
|
| 26 |
+
from unet_model import build_unet_3class # must be updated with the actual used model for traininig
|
| 27 |
+
|
| 28 |
+
# Import data loader
|
| 29 |
+
from p4_data_loader import DataConfig, P2DataLoader
|
| 30 |
+
|
| 31 |
+
# Error analysis
|
| 32 |
+
from p4_error_analysis import run_error_analysis
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
print("TensorFlow Version:", tf.__version__)
|
| 36 |
+
|
| 37 |
+
###################### GPU Configuration ######################
|
| 38 |
+
|
| 39 |
+
# Configure GPU memory growth
|
| 40 |
+
physical_devices = tf.config.list_physical_devices('GPU')
|
| 41 |
+
if physical_devices:
|
| 42 |
+
try:
|
| 43 |
+
for device in physical_devices:
|
| 44 |
+
tf.config.experimental.set_memory_growth(device, True)
|
| 45 |
+
print("✅ GPU memory growth enabled")
|
| 46 |
+
print(f" Available GPUs: {len(physical_devices)}")
|
| 47 |
+
except RuntimeError as e:
|
| 48 |
+
print(f"GPU configuration error: {e}")
|
| 49 |
+
else:
|
| 50 |
+
print("⚠️ No GPU detected - inference will be slow")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
###################### Inference Configuration ######################
|
| 54 |
+
|
| 55 |
+
class InferenceConfig:
|
| 56 |
+
"""Configuration for inference"""
|
| 57 |
+
|
| 58 |
+
def __init__(self,
|
| 59 |
+
variant: int = 5,
|
| 60 |
+
preprocessing: str = 'standard',
|
| 61 |
+
class_scenario: str = '4class',
|
| 62 |
+
fold_id: int = 0,
|
| 63 |
+
model_name: str = 'best_dice_generator.h5',
|
| 64 |
+
architecture_name: str = 'unet'
|
| 65 |
+
):
|
| 66 |
+
|
| 67 |
+
# Experiment identification
|
| 68 |
+
self.variant = variant
|
| 69 |
+
self.preprocessing = preprocessing
|
| 70 |
+
self.class_scenario = class_scenario
|
| 71 |
+
self.fold_id = fold_id
|
| 72 |
+
self.model_name = model_name
|
| 73 |
+
self.architecture_name = architecture_name
|
| 74 |
+
|
| 75 |
+
# Number of classes
|
| 76 |
+
self.num_classes = 3 if class_scenario == '3class' else 4
|
| 77 |
+
|
| 78 |
+
# Class names
|
| 79 |
+
if self.num_classes == 4:
|
| 80 |
+
self.class_names = ['Background', 'Ventricles', 'Normal_WMH', 'Abnormal_WMH']
|
| 81 |
+
elif self.num_classes == 3:
|
| 82 |
+
self.class_names = ['Background', 'Ventricles', 'Abnormal_WMH']
|
| 83 |
+
|
| 84 |
+
# Image dimensions
|
| 85 |
+
self.batch_size = 1 # Use batch_size=1 for inference
|
| 86 |
+
self.img_width = 256
|
| 87 |
+
self.img_height = 256
|
| 88 |
+
|
| 89 |
+
# Paths
|
| 90 |
+
self.results_dir = Path(f"results_fold_{fold_id}_var_{variant}_zscore2")
|
| 91 |
+
self.models_dir = self.results_dir / "models" / f"{preprocessing}_{class_scenario}"
|
| 92 |
+
self.checkpoint_dir = self.models_dir / f"fold_{fold_id}"
|
| 93 |
+
|
| 94 |
+
# Output directories
|
| 95 |
+
self.inference_dir = self.results_dir / "inference_all_test" / f"{preprocessing}_{class_scenario}"
|
| 96 |
+
# self.predictions_dir = self.inference_dir / "predictions"
|
| 97 |
+
self.visualizations_dir = self.inference_dir / "visualizations"
|
| 98 |
+
self.metrics_dir = self.inference_dir / "metrics"
|
| 99 |
+
|
| 100 |
+
# Create directories
|
| 101 |
+
# self.predictions_dir.mkdir(parents=True, exist_ok=True)
|
| 102 |
+
self.visualizations_dir.mkdir(parents=True, exist_ok=True)
|
| 103 |
+
self.metrics_dir.mkdir(parents=True, exist_ok=True)
|
| 104 |
+
|
| 105 |
+
# Model path
|
| 106 |
+
self.model_path = self.checkpoint_dir / self.model_name
|
| 107 |
+
|
| 108 |
+
# Check if model exists
|
| 109 |
+
if not self.model_path.exists():
|
| 110 |
+
raise FileNotFoundError(f"Model not found: {self.model_path}")
|
| 111 |
+
|
| 112 |
+
print(f"\n{'='*70}")
|
| 113 |
+
print(f"INFERENCE CONFIGURATION")
|
| 114 |
+
print(f"{'='*70}")
|
| 115 |
+
print(f"Variant: {self.variant}")
|
| 116 |
+
print(f"Preprocessing: {self.preprocessing}")
|
| 117 |
+
print(f"Class scenario: {self.class_scenario} ({self.num_classes} classes)")
|
| 118 |
+
print(f"Fold: {self.fold_id}")
|
| 119 |
+
print(f"Architecture: {self.architecture_name}")
|
| 120 |
+
print(f"Model: {self.model_name}")
|
| 121 |
+
print(f"Model path: {self.model_path}")
|
| 122 |
+
print(f"Output directory: {self.inference_dir}")
|
| 123 |
+
print(f"{'='*70}\n")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
###################### Utility Functions ######################
|
| 127 |
+
|
| 128 |
+
def prepare_input(paired_input):
|
| 129 |
+
"""
|
| 130 |
+
Extract and normalize FLAIR from paired input
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
paired_input: (bs, 256, 512, 1) with FLAIR + mask
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
flair_normalized: FLAIR normalized to [-1, 1]
|
| 137 |
+
"""
|
| 138 |
+
# Extract FLAIR (left half)
|
| 139 |
+
flair_normalized = paired_input[:, :, :256, :]
|
| 140 |
+
return flair_normalized
|
| 141 |
+
|
| 142 |
+
def compute_hd95(mask1, mask2):
|
| 143 |
+
"""
|
| 144 |
+
Compute 95th percentile Hausdorff Distance between two binary masks
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
mask1: Binary mask 1
|
| 148 |
+
mask2: Binary mask 2
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
HD95 value in pixels
|
| 152 |
+
"""
|
| 153 |
+
# Get boundary points
|
| 154 |
+
if not np.any(mask1) or not np.any(mask2):
|
| 155 |
+
return np.nan
|
| 156 |
+
|
| 157 |
+
# Compute distance transforms
|
| 158 |
+
dt1 = distance_transform_edt(~mask1.astype(bool))
|
| 159 |
+
dt2 = distance_transform_edt(~mask2.astype(bool))
|
| 160 |
+
|
| 161 |
+
# Get surface points
|
| 162 |
+
surface1 = mask1.astype(bool) & (dt1 <= 1)
|
| 163 |
+
surface2 = mask2.astype(bool) & (dt2 <= 1)
|
| 164 |
+
|
| 165 |
+
if not np.any(surface1) or not np.any(surface2):
|
| 166 |
+
return np.nan
|
| 167 |
+
|
| 168 |
+
# Get coordinates of surface points
|
| 169 |
+
coords1 = np.argwhere(surface1)
|
| 170 |
+
coords2 = np.argwhere(surface2)
|
| 171 |
+
|
| 172 |
+
# Compute distances from surface1 to surface2
|
| 173 |
+
distances1 = np.min(np.sqrt(np.sum((coords1[:, np.newaxis, :] - coords2[np.newaxis, :, :]) ** 2, axis=2)), axis=1)
|
| 174 |
+
# Compute distances from surface2 to surface1
|
| 175 |
+
distances2 = np.min(np.sqrt(np.sum((coords2[:, np.newaxis, :] - coords1[np.newaxis, :, :]) ** 2, axis=2)), axis=1)
|
| 176 |
+
|
| 177 |
+
# Combine distances
|
| 178 |
+
all_distances = np.concatenate([distances1, distances2])
|
| 179 |
+
|
| 180 |
+
# Return 95th percentile
|
| 181 |
+
return np.percentile(all_distances, 95)
|
| 182 |
+
|
| 183 |
+
def compute_hd95_3d(mask1, mask2):
|
| 184 |
+
"""
|
| 185 |
+
Compute 95th percentile Hausdorff Distance for 3D volume
|
| 186 |
+
Uses only surface voxels for efficiency
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
mask1: Binary mask (N, H, W)
|
| 190 |
+
mask2: Binary mask (N, H, W)
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
HD95 value in pixels
|
| 194 |
+
"""
|
| 195 |
+
if not np.any(mask1) or not np.any(mask2):
|
| 196 |
+
return np.nan
|
| 197 |
+
|
| 198 |
+
# Extract surface voxels only (border voxels)
|
| 199 |
+
from scipy.ndimage import binary_erosion
|
| 200 |
+
|
| 201 |
+
# Surface = original mask minus eroded mask
|
| 202 |
+
surface1 = mask1.astype(bool) & ~binary_erosion(mask1.astype(bool))
|
| 203 |
+
surface2 = mask2.astype(bool) & ~binary_erosion(mask2.astype(bool))
|
| 204 |
+
|
| 205 |
+
# Get surface coordinates
|
| 206 |
+
coords1 = np.argwhere(surface1)
|
| 207 |
+
coords2 = np.argwhere(surface2)
|
| 208 |
+
|
| 209 |
+
if len(coords1) == 0 or len(coords2) == 0:
|
| 210 |
+
return np.nan
|
| 211 |
+
|
| 212 |
+
# Subsample if still too large (>10k points each)
|
| 213 |
+
max_points = 10000
|
| 214 |
+
if len(coords1) > max_points:
|
| 215 |
+
idx1 = np.random.choice(len(coords1), max_points, replace=False)
|
| 216 |
+
coords1 = coords1[idx1]
|
| 217 |
+
if len(coords2) > max_points:
|
| 218 |
+
idx2 = np.random.choice(len(coords2), max_points, replace=False)
|
| 219 |
+
coords2 = coords2[idx2]
|
| 220 |
+
|
| 221 |
+
# Compute distances
|
| 222 |
+
distances1 = np.min(cdist(coords1, coords2, metric='euclidean'), axis=1)
|
| 223 |
+
distances2 = np.min(cdist(coords2, coords1, metric='euclidean'), axis=1)
|
| 224 |
+
|
| 225 |
+
# Combine all distances
|
| 226 |
+
all_distances = np.concatenate([distances1, distances2])
|
| 227 |
+
|
| 228 |
+
# Return 95th percentile
|
| 229 |
+
return np.percentile(all_distances, 95)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def compute_lesion_level_metrics(gt_volume, pred_volume, iou_threshold=0.1):
|
| 233 |
+
"""
|
| 234 |
+
Compute lesion-level (instance-level) metrics by treating each connected
|
| 235 |
+
component in the GT as an individual lesion.
|
| 236 |
+
|
| 237 |
+
A GT lesion is considered DETECTED if its overlap (IoU) with any single
|
| 238 |
+
predicted component exceeds `iou_threshold`.
|
| 239 |
+
A predicted component is a TRUE POSITIVE if it overlaps any GT lesion
|
| 240 |
+
above threshold, otherwise it is a FALSE POSITIVE lesion.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
gt_volume : binary 3-D numpy array (S, H, W) — ground truth for ONE class
|
| 244 |
+
pred_volume : binary 3-D numpy array (S, H, W) — prediction for ONE class
|
| 245 |
+
iou_threshold: minimum IoU to count a GT lesion as detected (default 0.1)
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
dict with keys:
|
| 249 |
+
n_gt_lesions : total number of GT lesions
|
| 250 |
+
n_pred_lesions : total number of predicted lesion clusters
|
| 251 |
+
tp_lesions : GT lesions that were detected
|
| 252 |
+
fn_lesions : GT lesions that were missed
|
| 253 |
+
fp_lesions : predicted clusters with no GT overlap
|
| 254 |
+
lesion_sensitivity: tp_lesions / n_gt_lesions
|
| 255 |
+
lesion_precision : tp_lesions / n_pred_lesions
|
| 256 |
+
lesion_f1 : harmonic mean of lesion sensitivity and precision
|
| 257 |
+
"""
|
| 258 |
+
gt_bin = gt_volume.astype(bool)
|
| 259 |
+
pred_bin = pred_volume.astype(bool)
|
| 260 |
+
|
| 261 |
+
# Label connected components
|
| 262 |
+
gt_labeled, n_gt = nd_label(gt_bin)
|
| 263 |
+
pred_labeled, n_pred = nd_label(pred_bin)
|
| 264 |
+
|
| 265 |
+
tp_lesions = 0
|
| 266 |
+
detected_pred_ids = set()
|
| 267 |
+
|
| 268 |
+
for gt_id in range(1, n_gt + 1):
|
| 269 |
+
gt_mask = (gt_labeled == gt_id)
|
| 270 |
+
# Find all predicted components that overlap this GT lesion
|
| 271 |
+
overlapping_pred_ids = np.unique(pred_labeled[gt_mask])
|
| 272 |
+
overlapping_pred_ids = overlapping_pred_ids[overlapping_pred_ids > 0]
|
| 273 |
+
|
| 274 |
+
detected = False
|
| 275 |
+
for pred_id in overlapping_pred_ids:
|
| 276 |
+
pred_mask = (pred_labeled == pred_id)
|
| 277 |
+
intersection = np.logical_and(gt_mask, pred_mask).sum()
|
| 278 |
+
union = np.logical_or(gt_mask, pred_mask).sum()
|
| 279 |
+
iou = intersection / (union + 1e-7)
|
| 280 |
+
if iou >= iou_threshold:
|
| 281 |
+
detected = True
|
| 282 |
+
detected_pred_ids.add(pred_id)
|
| 283 |
+
|
| 284 |
+
if detected:
|
| 285 |
+
tp_lesions += 1
|
| 286 |
+
|
| 287 |
+
fn_lesions = n_gt - tp_lesions
|
| 288 |
+
fp_lesions = n_pred - len(detected_pred_ids)
|
| 289 |
+
|
| 290 |
+
lesion_sensitivity = tp_lesions / (n_gt + 1e-7)
|
| 291 |
+
lesion_precision = tp_lesions / (n_pred + 1e-7) if n_pred > 0 else 0.0
|
| 292 |
+
lesion_f1 = (2 * lesion_sensitivity * lesion_precision /
|
| 293 |
+
(lesion_sensitivity + lesion_precision + 1e-7))
|
| 294 |
+
|
| 295 |
+
return {
|
| 296 |
+
'n_gt_lesions' : int(n_gt),
|
| 297 |
+
'n_pred_lesions' : int(n_pred),
|
| 298 |
+
'tp_lesions' : int(tp_lesions),
|
| 299 |
+
'fn_lesions' : int(fn_lesions),
|
| 300 |
+
'fp_lesions' : int(fp_lesions),
|
| 301 |
+
'lesion_sensitivity' : float(lesion_sensitivity),
|
| 302 |
+
'lesion_precision' : float(lesion_precision),
|
| 303 |
+
'lesion_f1' : float(lesion_f1),
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def compute_metrics_from_predictions(y_true, y_pred, num_classes, exclude_class=None):
|
| 308 |
+
"""
|
| 309 |
+
Compute comprehensive metrics from predictions
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
y_true: Ground truth class labels (N, H, W)
|
| 313 |
+
y_pred: Predicted class labels (N, H, W)
|
| 314 |
+
num_classes: Number of classes
|
| 315 |
+
exclude_class: Class to exclude from metrics (e.g., 2 for Normal_WMH in 4-class)
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
Dictionary containing metrics
|
| 319 |
+
"""
|
| 320 |
+
# Convert to one-hot
|
| 321 |
+
y_true_onehot = tf.one_hot(y_true, depth=num_classes, dtype=tf.float32)
|
| 322 |
+
y_pred_onehot = tf.one_hot(y_pred, depth=num_classes, dtype=tf.float32)
|
| 323 |
+
|
| 324 |
+
# Flatten spatial dimensions
|
| 325 |
+
y_true_flat = tf.reshape(y_true_onehot, [-1, num_classes])
|
| 326 |
+
y_pred_flat = tf.reshape(y_pred_onehot, [-1, num_classes])
|
| 327 |
+
|
| 328 |
+
# Convert to numpy
|
| 329 |
+
y_true_np = y_true_flat.numpy()
|
| 330 |
+
y_pred_np = y_pred_flat.numpy()
|
| 331 |
+
|
| 332 |
+
metrics = {
|
| 333 |
+
'dice': {},
|
| 334 |
+
'precision': {},
|
| 335 |
+
'recall': {},
|
| 336 |
+
'iou': {},
|
| 337 |
+
'specificity': {},
|
| 338 |
+
'hd95': {},
|
| 339 |
+
'TP': {}
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
classes_to_evaluate = [c for c in range(num_classes) if c != exclude_class]
|
| 343 |
+
|
| 344 |
+
for class_idx in classes_to_evaluate:
|
| 345 |
+
# Extract binary masks for this class
|
| 346 |
+
true_class = y_true_np[:, class_idx]
|
| 347 |
+
pred_class = y_pred_np[:, class_idx]
|
| 348 |
+
|
| 349 |
+
# Compute confusion matrix elements
|
| 350 |
+
TP = np.sum((true_class == 1) & (pred_class == 1))
|
| 351 |
+
FP = np.sum((true_class == 0) & (pred_class == 1))
|
| 352 |
+
FN = np.sum((true_class == 1) & (pred_class == 0))
|
| 353 |
+
TN = np.sum((true_class == 0) & (pred_class == 0))
|
| 354 |
+
|
| 355 |
+
# Dice Score: 2*TP / (2*TP + FP + FN)
|
| 356 |
+
dice = (2 * TP) / (2 * TP + FP + FN + 1e-7)
|
| 357 |
+
|
| 358 |
+
# Precision: TP / (TP + FP)
|
| 359 |
+
precision = TP / (TP + FP + 1e-7)
|
| 360 |
+
|
| 361 |
+
# Recall (Sensitivity): TP / (TP + FN)
|
| 362 |
+
recall = TP / (TP + FN + 1e-7)
|
| 363 |
+
|
| 364 |
+
# IoU (Jaccard): TP / (TP + FP + FN)
|
| 365 |
+
iou = TP / (TP + FP + FN + 1e-7)
|
| 366 |
+
|
| 367 |
+
# Specificity: TN / (TN + FP)
|
| 368 |
+
specificity = TN / (TN + FP + 1e-7)
|
| 369 |
+
|
| 370 |
+
# HD95: Hausdorff Distance 95th percentile
|
| 371 |
+
# Compute on entire volume (all samples combined) for fairness
|
| 372 |
+
true_class_volume = y_true_np[:, class_idx].reshape(y_true.shape[0], y_true.shape[1], y_true.shape[2])
|
| 373 |
+
pred_class_volume = y_pred_np[:, class_idx].reshape(y_pred.shape[0], y_pred.shape[1], y_pred.shape[2])
|
| 374 |
+
|
| 375 |
+
hd95_value = compute_hd95_3d(true_class_volume, pred_class_volume)
|
| 376 |
+
|
| 377 |
+
metrics['dice'][f'class_{class_idx}'] = float(dice)
|
| 378 |
+
metrics['precision'][f'class_{class_idx}'] = float(precision)
|
| 379 |
+
metrics['recall'][f'class_{class_idx}'] = float(recall)
|
| 380 |
+
metrics['iou'][f'class_{class_idx}'] = float(iou)
|
| 381 |
+
metrics['specificity'][f'class_{class_idx}'] = float(specificity)
|
| 382 |
+
metrics['hd95'][f'class_{class_idx}'] = float(hd95_value)
|
| 383 |
+
metrics['TP'][f'class_{class_idx}'] = float(TP)
|
| 384 |
+
|
| 385 |
+
# Compute mean metrics (excluding the excluded class)
|
| 386 |
+
for metric_name in ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']:
|
| 387 |
+
metrics[metric_name]['mean'] = np.mean([v for v in metrics[metric_name].values()])
|
| 388 |
+
|
| 389 |
+
# --- Lesion-level metrics (connected-component analysis) ---
|
| 390 |
+
metrics['lesion'] = {}
|
| 391 |
+
for class_idx in classes_to_evaluate:
|
| 392 |
+
if class_idx <= 1: # skip background and ventricles
|
| 393 |
+
continue
|
| 394 |
+
true_vol = y_true_np[:, class_idx].reshape(y_true.shape)
|
| 395 |
+
pred_vol = y_pred_np[:, class_idx].reshape(y_pred.shape)
|
| 396 |
+
metrics['lesion'][f'class_{class_idx}'] = compute_lesion_level_metrics(
|
| 397 |
+
true_vol, pred_vol, iou_threshold=0.1
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
return metrics
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# def aggregate_patient_metrics(per_patient_metrics, num_classes):
|
| 404 |
+
# """
|
| 405 |
+
# Returns both a flat structure (compatible with original overall_metrics)
|
| 406 |
+
# and an extended structure with std/n for richer reporting.
|
| 407 |
+
# """
|
| 408 |
+
# flat_metrics = {m: {} for m in ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']}
|
| 409 |
+
# rich_metrics = {m: {} for m in ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']}
|
| 410 |
+
|
| 411 |
+
# metric_names = ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']
|
| 412 |
+
|
| 413 |
+
# for metric_name in metric_names:
|
| 414 |
+
# for class_idx in range(num_classes):
|
| 415 |
+
# if class_idx == 0: continue
|
| 416 |
+
|
| 417 |
+
# key = f'class_{class_idx}'
|
| 418 |
+
|
| 419 |
+
# values = [
|
| 420 |
+
# per_patient_metrics[pid][metric_name][key]
|
| 421 |
+
# for pid in per_patient_metrics
|
| 422 |
+
# if key in per_patient_metrics[pid][metric_name]
|
| 423 |
+
# and not np.isnan(per_patient_metrics[pid][metric_name][key])
|
| 424 |
+
# ]
|
| 425 |
+
|
| 426 |
+
# TP_values = [
|
| 427 |
+
# per_patient_metrics[pid]['TP'][key]
|
| 428 |
+
# for pid in per_patient_metrics
|
| 429 |
+
# if key in per_patient_metrics[pid]['TP']
|
| 430 |
+
# and not np.isnan(per_patient_metrics[pid]['TP'][key])
|
| 431 |
+
# ]
|
| 432 |
+
|
| 433 |
+
# weighted_mean_values = np.sum((np.array(values) * np.array(TP_values)) / np.sum(np.array(TP_values)))
|
| 434 |
+
|
| 435 |
+
# mean_val = float(np.mean(values)) if values else np.nan
|
| 436 |
+
# std_val = float(np.std(values)) if values else np.nan
|
| 437 |
+
|
| 438 |
+
# # Flat: backward compatible with all existing print/save code
|
| 439 |
+
# flat_metrics[metric_name][key] = weighted_mean_values if metric_name != 'hd95' else mean_val
|
| 440 |
+
|
| 441 |
+
# # Rich: for extended reporting
|
| 442 |
+
# rich_metrics[metric_name][key] = {
|
| 443 |
+
# 'mean': mean_val,
|
| 444 |
+
# 'std': std_val,
|
| 445 |
+
# 'n': len(values)
|
| 446 |
+
# }
|
| 447 |
+
|
| 448 |
+
# # Mean across classes — same for both
|
| 449 |
+
# class_means = [
|
| 450 |
+
# flat_metrics[metric_name][f'class_{c}']
|
| 451 |
+
# for c in range(num_classes)
|
| 452 |
+
# if c!=0 and not np.isnan(flat_metrics[metric_name][f'class_{c}'])
|
| 453 |
+
# ]
|
| 454 |
+
# mean_across_classes = float(np.mean(class_means)) if class_means else np.nan
|
| 455 |
+
# flat_metrics[metric_name]['mean'] = mean_across_classes
|
| 456 |
+
# rich_metrics[metric_name]['mean'] = mean_across_classes
|
| 457 |
+
|
| 458 |
+
# return flat_metrics, rich_metrics
|
| 459 |
+
|
| 460 |
+
def aggregate_patient_metrics(per_patient_metrics, num_classes):
|
| 461 |
+
"""
|
| 462 |
+
Returns both a flat structure (compatible with original overall_metrics)
|
| 463 |
+
and an extended structure with std/n for richer reporting.
|
| 464 |
+
|
| 465 |
+
Includes lesion-level metrics (connected-component analysis):
|
| 466 |
+
- lesion_sensitivity : mean across patients of (tp_lesions / n_gt_lesions)
|
| 467 |
+
- lesion_precision : mean across patients of (tp_lesions / n_pred_lesions)
|
| 468 |
+
- lesion_f1 : mean across patients of harmonic mean of the above
|
| 469 |
+
- n_gt_lesions : total GT lesions summed across all patients
|
| 470 |
+
- n_pred_lesions : total predicted lesion clusters summed across all patients
|
| 471 |
+
- tp_lesions : total TP lesions summed across all patients
|
| 472 |
+
- fn_lesions : total FN lesions summed across all patients
|
| 473 |
+
- fp_lesions : total FP lesions summed across all patients
|
| 474 |
+
"""
|
| 475 |
+
# ── Voxel-level metrics (unchanged) ─────────────────────────────────────
|
| 476 |
+
voxel_metric_names = ['dice', 'precision', 'recall', 'iou', 'specificity', 'hd95', 'TP']
|
| 477 |
+
flat_metrics = {m: {} for m in voxel_metric_names}
|
| 478 |
+
rich_metrics = {m: {} for m in voxel_metric_names}
|
| 479 |
+
|
| 480 |
+
for metric_name in voxel_metric_names:
|
| 481 |
+
for class_idx in range(num_classes):
|
| 482 |
+
if class_idx == 0:
|
| 483 |
+
continue
|
| 484 |
+
|
| 485 |
+
key = f'class_{class_idx}'
|
| 486 |
+
|
| 487 |
+
values = [
|
| 488 |
+
per_patient_metrics[pid][metric_name][key]
|
| 489 |
+
for pid in per_patient_metrics
|
| 490 |
+
if key in per_patient_metrics[pid][metric_name]
|
| 491 |
+
and not np.isnan(per_patient_metrics[pid][metric_name][key])
|
| 492 |
+
]
|
| 493 |
+
|
| 494 |
+
TP_values = [
|
| 495 |
+
per_patient_metrics[pid]['TP'][key]
|
| 496 |
+
for pid in per_patient_metrics
|
| 497 |
+
if key in per_patient_metrics[pid]['TP']
|
| 498 |
+
and not np.isnan(per_patient_metrics[pid]['TP'][key])
|
| 499 |
+
]
|
| 500 |
+
|
| 501 |
+
weighted_mean_values = np.sum(
|
| 502 |
+
(np.array(values) * np.array(TP_values)) / np.sum(np.array(TP_values))
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
mean_val = float(np.mean(values)) if values else np.nan
|
| 506 |
+
std_val = float(np.std(values)) if values else np.nan
|
| 507 |
+
|
| 508 |
+
flat_metrics[metric_name][key] = weighted_mean_values if metric_name != 'hd95' else mean_val
|
| 509 |
+
rich_metrics[metric_name][key] = {
|
| 510 |
+
'mean': mean_val,
|
| 511 |
+
'std': std_val,
|
| 512 |
+
'n': len(values)
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
# Mean across classes
|
| 516 |
+
class_means = [
|
| 517 |
+
flat_metrics[metric_name][f'class_{c}']
|
| 518 |
+
for c in range(num_classes)
|
| 519 |
+
if c != 0 and not np.isnan(flat_metrics[metric_name][f'class_{c}'])
|
| 520 |
+
]
|
| 521 |
+
mean_across_classes = float(np.mean(class_means)) if class_means else np.nan
|
| 522 |
+
flat_metrics[metric_name]['mean'] = mean_across_classes
|
| 523 |
+
rich_metrics[metric_name]['mean'] = mean_across_classes
|
| 524 |
+
|
| 525 |
+
# ── Lesion-level metrics (new) ───────────────────────────────────────────
|
| 526 |
+
# Scalar fields: averaged across patients (mean ± std)
|
| 527 |
+
lesion_scalar_keys = ['lesion_sensitivity', 'lesion_precision', 'lesion_f1']
|
| 528 |
+
# Count fields: summed across patients (total pool)
|
| 529 |
+
lesion_count_keys = ['n_gt_lesions', 'n_pred_lesions', 'tp_lesions', 'fn_lesions', 'fp_lesions']
|
| 530 |
+
|
| 531 |
+
flat_metrics['lesion'] = {}
|
| 532 |
+
rich_metrics['lesion'] = {}
|
| 533 |
+
|
| 534 |
+
for class_idx in range(num_classes):
|
| 535 |
+
if class_idx <= 1: # skip background and ventricles
|
| 536 |
+
continue
|
| 537 |
+
|
| 538 |
+
key = f'class_{class_idx}'
|
| 539 |
+
flat_metrics['lesion'][key] = {}
|
| 540 |
+
rich_metrics['lesion'][key] = {}
|
| 541 |
+
|
| 542 |
+
# --- Scalar metrics: mean ± std across patients ---
|
| 543 |
+
for sk in lesion_scalar_keys:
|
| 544 |
+
vals = [
|
| 545 |
+
per_patient_metrics[pid]['lesion'][key][sk]
|
| 546 |
+
for pid in per_patient_metrics
|
| 547 |
+
if 'lesion' in per_patient_metrics[pid]
|
| 548 |
+
and key in per_patient_metrics[pid]['lesion']
|
| 549 |
+
]
|
| 550 |
+
mean_val = float(np.mean(vals)) if vals else np.nan
|
| 551 |
+
std_val = float(np.std(vals)) if vals else np.nan
|
| 552 |
+
flat_metrics['lesion'][key][sk] = mean_val
|
| 553 |
+
rich_metrics['lesion'][key][sk] = {
|
| 554 |
+
'mean': mean_val,
|
| 555 |
+
'std': std_val,
|
| 556 |
+
'n': len(vals)
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
# --- Count metrics: sum across patients ---
|
| 560 |
+
for ck in lesion_count_keys:
|
| 561 |
+
vals = [
|
| 562 |
+
per_patient_metrics[pid]['lesion'][key][ck]
|
| 563 |
+
for pid in per_patient_metrics
|
| 564 |
+
if 'lesion' in per_patient_metrics[pid]
|
| 565 |
+
and key in per_patient_metrics[pid]['lesion']
|
| 566 |
+
]
|
| 567 |
+
flat_metrics['lesion'][key][ck] = int(np.sum(vals)) if vals else 0
|
| 568 |
+
rich_metrics['lesion'][key][ck] = int(np.sum(vals)) if vals else 0
|
| 569 |
+
|
| 570 |
+
# Mean lesion scalars across foreground classes
|
| 571 |
+
for sk in lesion_scalar_keys:
|
| 572 |
+
class_vals = [
|
| 573 |
+
flat_metrics['lesion'][f'class_{c}'][sk]
|
| 574 |
+
for c in range(num_classes)
|
| 575 |
+
if c > 1 and not np.isnan(flat_metrics['lesion'][f'class_{c}'][sk])
|
| 576 |
+
]
|
| 577 |
+
mean_across = float(np.mean(class_vals)) if class_vals else np.nan
|
| 578 |
+
flat_metrics['lesion'][f'mean_{sk}'] = mean_across
|
| 579 |
+
rich_metrics['lesion'][f'mean_{sk}'] = mean_across
|
| 580 |
+
|
| 581 |
+
# Summed counts across foreground classes
|
| 582 |
+
for ck in lesion_count_keys:
|
| 583 |
+
flat_metrics['lesion'][f'total_{ck}'] = int(np.sum([
|
| 584 |
+
flat_metrics['lesion'][f'class_{c}'][ck]
|
| 585 |
+
for c in range(num_classes) if c > 1
|
| 586 |
+
]))
|
| 587 |
+
rich_metrics['lesion'][f'total_{ck}'] = flat_metrics['lesion'][f'total_{ck}']
|
| 588 |
+
|
| 589 |
+
return flat_metrics, rich_metrics
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
###################### Original Visualization Functions ######################
|
| 593 |
+
|
| 594 |
+
def visualize_prediction(flair, ground_truth, prediction,
|
| 595 |
+
probability_map, save_path,
|
| 596 |
+
sample_id, num_classes):
|
| 597 |
+
"""
|
| 598 |
+
Create comprehensive visualization of prediction
|
| 599 |
+
|
| 600 |
+
Args:
|
| 601 |
+
flair: Input FLAIR image (H, W)
|
| 602 |
+
ground_truth: Ground truth mask (H, W)
|
| 603 |
+
prediction: Predicted mask (H, W)
|
| 604 |
+
probability_map: Max probability map (H, W)
|
| 605 |
+
save_path: Path to save figure
|
| 606 |
+
sample_id: Sample identifier
|
| 607 |
+
num_classes: Number of classes
|
| 608 |
+
"""
|
| 609 |
+
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
|
| 610 |
+
|
| 611 |
+
# Input FLAIR
|
| 612 |
+
axes[0, 0].imshow(flair, cmap='gray')
|
| 613 |
+
axes[0, 0].set_title('Input FLAIR', fontsize=14, fontweight='bold')
|
| 614 |
+
axes[0, 0].axis('off')
|
| 615 |
+
|
| 616 |
+
# Ground truth
|
| 617 |
+
im1 = axes[0, 1].imshow(ground_truth, cmap='jet', vmin=0, vmax=num_classes-1)
|
| 618 |
+
axes[0, 1].set_title('Ground Truth', fontsize=14, fontweight='bold')
|
| 619 |
+
axes[0, 1].axis('off')
|
| 620 |
+
plt.colorbar(im1, ax=axes[0, 1], fraction=0.046, pad=0.04)
|
| 621 |
+
|
| 622 |
+
# Prediction
|
| 623 |
+
im2 = axes[0, 2].imshow(prediction, cmap='jet', vmin=0, vmax=num_classes-1)
|
| 624 |
+
axes[0, 2].set_title('Prediction', fontsize=14, fontweight='bold')
|
| 625 |
+
axes[0, 2].axis('off')
|
| 626 |
+
plt.colorbar(im2, ax=axes[0, 2], fraction=0.046, pad=0.04)
|
| 627 |
+
|
| 628 |
+
# Max probability
|
| 629 |
+
im3 = axes[1, 0].imshow(probability_map, cmap='viridis', vmin=0, vmax=1)
|
| 630 |
+
axes[1, 0].set_title('Prediction Confidence', fontsize=14, fontweight='bold')
|
| 631 |
+
axes[1, 0].axis('off')
|
| 632 |
+
plt.colorbar(im3, ax=axes[1, 0], fraction=0.046, pad=0.04)
|
| 633 |
+
|
| 634 |
+
# Error map
|
| 635 |
+
error_map = (prediction != ground_truth).astype(float)
|
| 636 |
+
im4 = axes[1, 1].imshow(error_map, cmap='Reds', vmin=0, vmax=1)
|
| 637 |
+
axes[1, 1].set_title('Error Map (Red=Wrong)', fontsize=14, fontweight='bold')
|
| 638 |
+
axes[1, 1].axis('off')
|
| 639 |
+
plt.colorbar(im4, ax=axes[1, 1], fraction=0.046, pad=0.04)
|
| 640 |
+
|
| 641 |
+
# Overlay: FLAIR + Prediction contours
|
| 642 |
+
axes[1, 2].imshow(flair, cmap='gray')
|
| 643 |
+
# Create contours for each class
|
| 644 |
+
from scipy import ndimage
|
| 645 |
+
for class_idx in range(1, num_classes): # Skip background
|
| 646 |
+
class_mask = (prediction == class_idx)
|
| 647 |
+
contours = class_mask ^ ndimage.binary_erosion(class_mask)
|
| 648 |
+
if np.any(contours):
|
| 649 |
+
axes[1, 2].contour(contours, colors=[plt.cm.jet(class_idx/(num_classes-1))], linewidths=1.5)
|
| 650 |
+
axes[1, 2].set_title('FLAIR + Prediction Overlay', fontsize=14, fontweight='bold')
|
| 651 |
+
axes[1, 2].axis('off')
|
| 652 |
+
|
| 653 |
+
plt.suptitle(f'Sample: {sample_id}', fontsize=16, fontweight='bold', y=0.98)
|
| 654 |
+
plt.tight_layout()
|
| 655 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 656 |
+
plt.close()
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
def visualize_prediction_short(flair, ground_truth, prediction,
|
| 660 |
+
probability_map, save_path,
|
| 661 |
+
sample_id, num_classes):
|
| 662 |
+
"""
|
| 663 |
+
Create comprehensive visualization of prediction
|
| 664 |
+
|
| 665 |
+
Args:
|
| 666 |
+
flair: Input FLAIR image (H, W)
|
| 667 |
+
ground_truth: Ground truth mask (H, W)
|
| 668 |
+
prediction: Predicted mask (H, W)
|
| 669 |
+
probability_map: Max probability map (H, W)
|
| 670 |
+
save_path: Path to save figure
|
| 671 |
+
sample_id: Sample identifier
|
| 672 |
+
num_classes: Number of classes
|
| 673 |
+
"""
|
| 674 |
+
fig, axes = plt.subplots(2, 1, figsize=(6, 12))
|
| 675 |
+
|
| 676 |
+
cmap = plt.cm.jet
|
| 677 |
+
flair_norm = (flair - flair.min()) / (flair.max() - flair.min() + 1e-8)
|
| 678 |
+
flair_rgb = np.stack([flair_norm] * 3, axis=-1)
|
| 679 |
+
|
| 680 |
+
for ax, mask, title in zip(axes, [ground_truth, prediction], ['Ground Truth Overlay', 'Prediction Overlay']):
|
| 681 |
+
mask_rgb = cmap(mask / (num_classes - 1))[..., :3] # (H, W, 3)
|
| 682 |
+
foreground = mask > 0
|
| 683 |
+
alpha = np.where(foreground, 0.6, 0.0)[..., np.newaxis] # fade non-background
|
| 684 |
+
blended = flair_rgb * (1 - alpha) + mask_rgb * alpha
|
| 685 |
+
|
| 686 |
+
ax.imshow(blended)
|
| 687 |
+
# ax.set_title(title, fontsize=14, fontweight='bold')
|
| 688 |
+
ax.axis('off')
|
| 689 |
+
|
| 690 |
+
# Shared colorbar
|
| 691 |
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=num_classes - 1))
|
| 692 |
+
sm.set_array([])
|
| 693 |
+
# fig.colorbar(sm, ax=axes.ravel().tolist(), fraction=0.02, pad=0.04)
|
| 694 |
+
|
| 695 |
+
# plt.suptitle(f'Sample: {sample_id}', fontsize=16, fontweight='bold')
|
| 696 |
+
plt.tight_layout()
|
| 697 |
+
try:
|
| 698 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 699 |
+
except:
|
| 700 |
+
print(f"\n Unsaved image: {save_path}")
|
| 701 |
+
plt.close()
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def save_prediction_as_nifti(prediction, save_path, reference_nifti=None):
|
| 705 |
+
"""
|
| 706 |
+
Save prediction as NIfTI file
|
| 707 |
+
|
| 708 |
+
Args:
|
| 709 |
+
prediction: Prediction array (H, W) or (H, W, D)
|
| 710 |
+
save_path: Path to save NIfTI file
|
| 711 |
+
reference_nifti: Optional reference NIfTI for header info
|
| 712 |
+
"""
|
| 713 |
+
if reference_nifti is not None:
|
| 714 |
+
# Use reference header
|
| 715 |
+
nifti_img = nib.Nifti1Image(prediction.astype(np.uint8), reference_nifti.affine, reference_nifti.header)
|
| 716 |
+
else:
|
| 717 |
+
# Create new NIfTI with identity affine
|
| 718 |
+
nifti_img = nib.Nifti1Image(prediction.astype(np.uint8), np.eye(4))
|
| 719 |
+
|
| 720 |
+
nib.save(nifti_img, save_path)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
###################### Post-processing Function ######################
|
| 724 |
+
|
| 725 |
+
def post_process_pred(pred_classes, num_classes=3, min_object_size=5, closing_kernel_size=2):
|
| 726 |
+
"""
|
| 727 |
+
Post-process a single 2-D multi-class prediction slice.
|
| 728 |
+
|
| 729 |
+
Input
|
| 730 |
+
-----
|
| 731 |
+
pred_classes : np.ndarray of shape (H, W) — integer class labels
|
| 732 |
+
produced by tf.argmax(...).numpy()[0] inside the
|
| 733 |
+
inference loop (one slice at a time).
|
| 734 |
+
num_classes : 3 → classes are 0=BG, 1=Vent, 2=AbWMH
|
| 735 |
+
4 → classes are 0=BG, 1=Vent, 2=NormWMH, 3=AbWMH
|
| 736 |
+
min_object_size : connected components smaller than this (pixels) are
|
| 737 |
+
removed after morphological cleaning. Default 5.
|
| 738 |
+
closing_kernel_size: radius of the disk used for binary_closing. Default 2.
|
| 739 |
+
|
| 740 |
+
Output
|
| 741 |
+
------
|
| 742 |
+
post_pred : np.ndarray of shape (H, W), same dtype as pred_classes,
|
| 743 |
+
with cleaned and overlap-resolved integer class labels.
|
| 744 |
+
|
| 745 |
+
Processing pipeline (per class)
|
| 746 |
+
--------------------------------
|
| 747 |
+
1. Extract binary mask for each foreground class from the label map.
|
| 748 |
+
2. Apply binary_closing → fill small holes / bridge tiny gaps.
|
| 749 |
+
3. Apply remove_small_objects → discard isolated noise specks.
|
| 750 |
+
4. Resolve overlaps by anatomical priority:
|
| 751 |
+
Ventricles > Normal WMH > Abnormal WMH
|
| 752 |
+
(a higher-priority class always wins contested pixels)
|
| 753 |
+
5. Reconstruct the integer label map from the cleaned binary masks.
|
| 754 |
+
"""
|
| 755 |
+
from skimage.morphology import remove_small_objects, binary_erosion, binary_closing, disk, binary_dilation
|
| 756 |
+
|
| 757 |
+
kernel = disk(closing_kernel_size)
|
| 758 |
+
|
| 759 |
+
def clean(mask):
|
| 760 |
+
"""Apply closing + small-object removal to a single binary mask."""
|
| 761 |
+
if not mask.any():
|
| 762 |
+
return mask
|
| 763 |
+
mask = binary_closing(mask, kernel)
|
| 764 |
+
# mask = binary_erosion(mask, disk(1))
|
| 765 |
+
mask = remove_small_objects(mask, min_size=min_object_size)
|
| 766 |
+
return mask
|
| 767 |
+
|
| 768 |
+
# ── 1. Extract per-class binary masks from the 2-D label map ────────────
|
| 769 |
+
vent_mask = (pred_classes == 1)
|
| 770 |
+
|
| 771 |
+
if num_classes == 4:
|
| 772 |
+
nwmh_mask = (pred_classes == 2)
|
| 773 |
+
abwmh_mask = (pred_classes == 3)
|
| 774 |
+
else:
|
| 775 |
+
# 3-class scenario: no Normal WMH, AbWMH is class 2
|
| 776 |
+
nwmh_mask = np.zeros_like(vent_mask)
|
| 777 |
+
abwmh_mask = (pred_classes == 2)
|
| 778 |
+
|
| 779 |
+
# ── 2-3. Morphological cleaning per class ───────────────────────────────
|
| 780 |
+
vent_mask = clean(vent_mask)
|
| 781 |
+
nwmh_mask = clean(nwmh_mask)
|
| 782 |
+
abwmh_mask = clean(abwmh_mask)
|
| 783 |
+
|
| 784 |
+
# ── 4. Resolve overlaps: higher-priority mask wins ───────────────────────
|
| 785 |
+
# Ventricles > Normal WMH > Abnormal WMH
|
| 786 |
+
nwmh_mask = nwmh_mask & ~vent_mask # NormWMH cannot overlap Vent
|
| 787 |
+
abwmh_mask = abwmh_mask & ~vent_mask # AbWMH cannot overlap Vent
|
| 788 |
+
abwmh_mask = abwmh_mask & ~nwmh_mask # AbWMH cannot overlap NormWMH
|
| 789 |
+
|
| 790 |
+
# ── 5. Reconstruct the integer label map ─────────────────────────────────
|
| 791 |
+
post_pred = np.zeros_like(pred_classes) # background = 0
|
| 792 |
+
post_pred[vent_mask] = 1
|
| 793 |
+
|
| 794 |
+
if num_classes == 4:
|
| 795 |
+
post_pred[nwmh_mask] = 2
|
| 796 |
+
post_pred[abwmh_mask] = 3
|
| 797 |
+
else:
|
| 798 |
+
post_pred[abwmh_mask] = 2
|
| 799 |
+
|
| 800 |
+
return post_pred
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
###################### Main Inference Function ######################
|
| 804 |
+
|
| 805 |
+
def run_inference(config: InferenceConfig):
|
| 806 |
+
"""
|
| 807 |
+
Main inference function
|
| 808 |
+
|
| 809 |
+
Args:
|
| 810 |
+
config: InferenceConfig object
|
| 811 |
+
|
| 812 |
+
Returns:
|
| 813 |
+
Dictionary containing all predictions and metrics
|
| 814 |
+
"""
|
| 815 |
+
print("\n" + "="*70)
|
| 816 |
+
print(f"RUNNING INFERENCE")
|
| 817 |
+
print("="*70)
|
| 818 |
+
|
| 819 |
+
# Initialize data loader
|
| 820 |
+
data_config = DataConfig()
|
| 821 |
+
data_loader = P2DataLoader(data_config)
|
| 822 |
+
|
| 823 |
+
# Load test dataset
|
| 824 |
+
print("Loading test data...")
|
| 825 |
+
test_dataset = data_loader.create_dataset_for_fold(
|
| 826 |
+
fold_id=config.fold_id,
|
| 827 |
+
split='test',
|
| 828 |
+
preprocessing=config.preprocessing,
|
| 829 |
+
class_scenario=config.class_scenario,
|
| 830 |
+
batch_size=config.batch_size,
|
| 831 |
+
shuffle=False
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
# Get dataset size
|
| 835 |
+
test_size = tf.data.experimental.cardinality(test_dataset).numpy()
|
| 836 |
+
if test_size < 0:
|
| 837 |
+
test_size = sum(1 for _ in test_dataset)
|
| 838 |
+
test_dataset = data_loader.create_dataset_for_fold(
|
| 839 |
+
fold_id=config.fold_id, split='test',
|
| 840 |
+
preprocessing=config.preprocessing,
|
| 841 |
+
class_scenario=config.class_scenario,
|
| 842 |
+
batch_size=config.batch_size, shuffle=False
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
print(f"Test samples: {test_size}\n")
|
| 846 |
+
|
| 847 |
+
# Load model
|
| 848 |
+
print(f"Loading model from: {config.model_path}")
|
| 849 |
+
try:
|
| 850 |
+
if config.architecture_name == 'unet':
|
| 851 |
+
from unet_model import build_unet_3class as build_specific_3class # must be updated with the actual used model for traininig
|
| 852 |
+
elif config.architecture_name == 'attnunet':
|
| 853 |
+
from attn_unet_model import build_attention_unet_3class as build_specific_3class
|
| 854 |
+
elif config.architecture_name == 'dlv3unet':
|
| 855 |
+
from dlv3_unet_model_GN import build_deeplabv3_unet_3class as build_specific_3class
|
| 856 |
+
elif config.architecture_name == 'transunet':
|
| 857 |
+
from trans_unet_model import build_trans_unet_3class as build_specific_3class
|
| 858 |
+
else:
|
| 859 |
+
print(f"❌ Error loading model: Invalid Model Name")
|
| 860 |
+
raise
|
| 861 |
+
|
| 862 |
+
# Build model architecture first
|
| 863 |
+
generator = build_specific_3class(
|
| 864 |
+
input_shape=(256, 256, 1),
|
| 865 |
+
num_classes=config.num_classes
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
# Load weights
|
| 869 |
+
generator.load_weights(str(config.model_path))
|
| 870 |
+
print("✅ Model loaded successfully\n")
|
| 871 |
+
|
| 872 |
+
except Exception as e:
|
| 873 |
+
print(f"❌ Error loading model: {e}")
|
| 874 |
+
raise
|
| 875 |
+
|
| 876 |
+
# Initialize storage - keyed by patient ID
|
| 877 |
+
patient_results = defaultdict(lambda: {
|
| 878 |
+
'predictions': [],
|
| 879 |
+
'ground_truths': [],
|
| 880 |
+
'probabilities': [],
|
| 881 |
+
'flairs': [],
|
| 882 |
+
'slice_indices': []
|
| 883 |
+
})
|
| 884 |
+
sample_ids = []
|
| 885 |
+
|
| 886 |
+
# Run inference
|
| 887 |
+
print("Running inference on test set...")
|
| 888 |
+
test_bar = tqdm(test_dataset, total=test_size, desc="Inference")
|
| 889 |
+
|
| 890 |
+
for idx, (paired_input, target_mask, patient_id_tensor, slice_num_tensor) in enumerate(test_bar):
|
| 891 |
+
|
| 892 |
+
patient_id = patient_id_tensor.numpy()[0].decode('utf-8') # batch dim + bytes→str
|
| 893 |
+
slice_num = int(slice_num_tensor.numpy()[0])
|
| 894 |
+
|
| 895 |
+
sample_ids.append(f"{patient_id}_slice_{slice_num:03d}")
|
| 896 |
+
|
| 897 |
+
# Prepare input
|
| 898 |
+
flair_normalized = prepare_input(paired_input)
|
| 899 |
+
|
| 900 |
+
# Generate prediction
|
| 901 |
+
prediction_softmax = generator(flair_normalized, training=False)
|
| 902 |
+
|
| 903 |
+
# Convert to class labels
|
| 904 |
+
pred_classes = tf.argmax(prediction_softmax, axis=-1).numpy()[0]
|
| 905 |
+
max_prob = tf.reduce_max(prediction_softmax, axis=-1).numpy()[0]
|
| 906 |
+
ground_truth = target_mask.numpy()[0]
|
| 907 |
+
flair = flair_normalized.numpy()[0, :, :, 0]
|
| 908 |
+
|
| 909 |
+
# Post-process the predictions
|
| 910 |
+
# pred_classes_post = post_process_pred(pred_classes, num_classes=config.num_classes)
|
| 911 |
+
|
| 912 |
+
# Store per-patient
|
| 913 |
+
patient_results[patient_id]['predictions'].append(pred_classes)
|
| 914 |
+
patient_results[patient_id]['ground_truths'].append(ground_truth)
|
| 915 |
+
patient_results[patient_id]['probabilities'].append(max_prob)
|
| 916 |
+
patient_results[patient_id]['flairs'].append(flair)
|
| 917 |
+
patient_results[patient_id]['slice_indices'].append(slice_num)
|
| 918 |
+
|
| 919 |
+
# Create visualization
|
| 920 |
+
if idx % 10 == 0 or True: # Visualize every 10th sample
|
| 921 |
+
# viz_path = config.visualizations_dir / f"visualization_{idx:04d}.png"
|
| 922 |
+
viz_path = config.visualizations_dir / f"{sample_ids[-1]}.png"
|
| 923 |
+
visualize_prediction_short(
|
| 924 |
+
flair, ground_truth, pred_classes,
|
| 925 |
+
max_prob, viz_path,
|
| 926 |
+
sample_ids[-1], config.num_classes
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
+
print("\n✅ Inference complete!\n")
|
| 930 |
+
|
| 931 |
+
# Compute overall metrics
|
| 932 |
+
print("Computing metrics...")
|
| 933 |
+
exclude_class = None
|
| 934 |
+
per_patient_metrics = {}
|
| 935 |
+
|
| 936 |
+
for patient_id, data in patient_results.items():
|
| 937 |
+
# Sort slices by anatomical order
|
| 938 |
+
order = np.argsort(data['slice_indices'])
|
| 939 |
+
|
| 940 |
+
gt_volume = np.array(data['ground_truths'])[order] # (S, H, W)
|
| 941 |
+
pred_volume = np.array(data['predictions'])[order] # (S, H, W)
|
| 942 |
+
|
| 943 |
+
per_patient_metrics[patient_id] = compute_metrics_from_predictions(
|
| 944 |
+
gt_volume,
|
| 945 |
+
pred_volume,
|
| 946 |
+
config.num_classes
|
| 947 |
+
)
|
| 948 |
+
print(f"\nPatint_id : {patient_id} , Stats: {per_patient_metrics[patient_id]}\n")
|
| 949 |
+
|
| 950 |
+
pm = per_patient_metrics[patient_id]
|
| 951 |
+
print(f"\nPatient_id: {patient_id}")
|
| 952 |
+
print(f" Voxel — Dice: { {k: round(v,4) for k,v in pm['dice'].items()} }")
|
| 953 |
+
if 'lesion' in pm:
|
| 954 |
+
for cls, ld in pm['lesion'].items():
|
| 955 |
+
print(f" Lesion [{cls}] — "
|
| 956 |
+
f"GT:{ld['n_gt_lesions']} Pred:{ld['n_pred_lesions']} "
|
| 957 |
+
f"TP:{ld['tp_lesions']} FP:{ld['fp_lesions']} FN:{ld['fn_lesions']} "
|
| 958 |
+
f"Sens:{ld['lesion_sensitivity']:.3f} Prec:{ld['lesion_precision']:.3f} "
|
| 959 |
+
f"F1:{ld['lesion_f1']:.3f}")
|
| 960 |
+
|
| 961 |
+
# Aggregate across patients
|
| 962 |
+
overall_metrics, overall_metrics_rich = aggregate_patient_metrics(
|
| 963 |
+
per_patient_metrics, config.num_classes
|
| 964 |
+
)
|
| 965 |
+
# overall_metrics → drop-in replacement for old overall_metrics, all print/save code unchanged
|
| 966 |
+
# overall_metrics_rich → use wherever we want mean ± std reporting
|
| 967 |
+
|
| 968 |
+
# Print standard metrics
|
| 969 |
+
print("\n" + "="*70)
|
| 970 |
+
print("STANDARD METRICS (Class vs Rest)")
|
| 971 |
+
print("="*70)
|
| 972 |
+
|
| 973 |
+
print("\nClass-wise Dice Scores:")
|
| 974 |
+
for class_idx, class_name in enumerate(config.class_names):
|
| 975 |
+
if exclude_class is not None and class_idx == exclude_class:
|
| 976 |
+
continue
|
| 977 |
+
key = f'class_{class_idx}'
|
| 978 |
+
if key in overall_metrics['dice']:
|
| 979 |
+
print(f" {class_name}: {overall_metrics['dice'][key]:.4f}")
|
| 980 |
+
print(f" Mean Dice: {overall_metrics['dice']['mean']:.4f}")
|
| 981 |
+
|
| 982 |
+
print("\nClass-wise Precision:")
|
| 983 |
+
for class_idx, class_name in enumerate(config.class_names):
|
| 984 |
+
if exclude_class is not None and class_idx == exclude_class:
|
| 985 |
+
continue
|
| 986 |
+
key = f'class_{class_idx}'
|
| 987 |
+
if key in overall_metrics['precision']:
|
| 988 |
+
print(f" {class_name}: {overall_metrics['precision'][key]:.4f}")
|
| 989 |
+
print(f" Mean Precision: {overall_metrics['precision']['mean']:.4f}")
|
| 990 |
+
|
| 991 |
+
print("\nClass-wise Recall:")
|
| 992 |
+
for class_idx, class_name in enumerate(config.class_names):
|
| 993 |
+
if exclude_class is not None and class_idx == exclude_class:
|
| 994 |
+
continue
|
| 995 |
+
key = f'class_{class_idx}'
|
| 996 |
+
if key in overall_metrics['recall']:
|
| 997 |
+
print(f" {class_name}: {overall_metrics['recall'][key]:.4f}")
|
| 998 |
+
print(f" Mean Recall: {overall_metrics['recall']['mean']:.4f}")
|
| 999 |
+
|
| 1000 |
+
print("\nClass-wise IoU:")
|
| 1001 |
+
for class_idx, class_name in enumerate(config.class_names):
|
| 1002 |
+
if exclude_class is not None and class_idx == exclude_class:
|
| 1003 |
+
continue
|
| 1004 |
+
key = f'class_{class_idx}'
|
| 1005 |
+
if key in overall_metrics['iou']:
|
| 1006 |
+
print(f" {class_name}: {overall_metrics['iou'][key]:.4f}")
|
| 1007 |
+
print(f" Mean IoU: {overall_metrics['iou']['mean']:.4f}")
|
| 1008 |
+
|
| 1009 |
+
print("\nClass-wise Specificity:")
|
| 1010 |
+
for class_idx, class_name in enumerate(config.class_names):
|
| 1011 |
+
if exclude_class is not None and class_idx == exclude_class:
|
| 1012 |
+
continue
|
| 1013 |
+
key = f'class_{class_idx}'
|
| 1014 |
+
if key in overall_metrics['specificity']:
|
| 1015 |
+
print(f" {class_name}: {overall_metrics['specificity'][key]:.4f}")
|
| 1016 |
+
print(f" Mean Specificity: {overall_metrics['specificity']['mean']:.4f}")
|
| 1017 |
+
|
| 1018 |
+
print("\nClass-wise HD95 (lower is better):")
|
| 1019 |
+
for class_idx, class_name in enumerate(config.class_names):
|
| 1020 |
+
if exclude_class is not None and class_idx == exclude_class:
|
| 1021 |
+
continue
|
| 1022 |
+
key = f'class_{class_idx}'
|
| 1023 |
+
if key in overall_metrics['hd95']:
|
| 1024 |
+
print(f" {class_name}: {overall_metrics['hd95'][key]:.4f}")
|
| 1025 |
+
print(f" Mean HD95: {overall_metrics['hd95']['mean']:.4f}")
|
| 1026 |
+
|
| 1027 |
+
print("="*70 + "\n")
|
| 1028 |
+
|
| 1029 |
+
# Print lesion-level metrics
|
| 1030 |
+
print("\n" + "="*70)
|
| 1031 |
+
print("LESION-LEVEL METRICS (Connected-Component Analysis)")
|
| 1032 |
+
print("="*70)
|
| 1033 |
+
|
| 1034 |
+
for class_idx, class_name in enumerate(config.class_names):
|
| 1035 |
+
if class_idx == 0:
|
| 1036 |
+
continue
|
| 1037 |
+
key = f'class_{class_idx}'
|
| 1038 |
+
if key not in overall_metrics.get('lesion', {}):
|
| 1039 |
+
continue
|
| 1040 |
+
ld = overall_metrics['lesion'][key]
|
| 1041 |
+
print(f"\n [{class_name}]")
|
| 1042 |
+
print(f" GT Lesions : {ld['n_gt_lesions']}")
|
| 1043 |
+
print(f" Predicted Lesions : {ld['n_pred_lesions']}")
|
| 1044 |
+
print(f" TP Lesions : {ld['tp_lesions']}")
|
| 1045 |
+
print(f" FP Lesions : {ld['fp_lesions']}")
|
| 1046 |
+
print(f" FN Lesions : {ld['fn_lesions']}")
|
| 1047 |
+
print(f" Lesion Sensitivity : {ld['lesion_sensitivity']:.4f}")
|
| 1048 |
+
print(f" Lesion Precision : {ld['lesion_precision']:.4f}")
|
| 1049 |
+
print(f" Lesion F1 : {ld['lesion_f1']:.4f}")
|
| 1050 |
+
|
| 1051 |
+
print(f"\n [Summary across foreground classes]")
|
| 1052 |
+
print(f" Total GT Lesions : {overall_metrics['lesion']['total_n_gt_lesions']}")
|
| 1053 |
+
print(f" Total Pred Lesions : {overall_metrics['lesion']['total_n_pred_lesions']}")
|
| 1054 |
+
print(f" Total TP Lesions : {overall_metrics['lesion']['total_tp_lesions']}")
|
| 1055 |
+
print(f" Total FP Lesions : {overall_metrics['lesion']['total_fp_lesions']}")
|
| 1056 |
+
print(f" Total FN Lesions : {overall_metrics['lesion']['total_fn_lesions']}")
|
| 1057 |
+
print(f" Mean Lesion Sensitivity : {overall_metrics['lesion']['mean_lesion_sensitivity']:.4f}")
|
| 1058 |
+
print(f" Mean Lesion Precision : {overall_metrics['lesion']['mean_lesion_precision']:.4f}")
|
| 1059 |
+
print(f" Mean Lesion F1 : {overall_metrics['lesion']['mean_lesion_f1']:.4f}")
|
| 1060 |
+
print("="*70 + "\n")
|
| 1061 |
+
|
| 1062 |
+
# Save all metrics to JSON
|
| 1063 |
+
metrics_file = config.metrics_dir / "test_metrics_complete.json"
|
| 1064 |
+
|
| 1065 |
+
def convert_to_serializable(obj):
|
| 1066 |
+
"""Convert numpy types to Python native types"""
|
| 1067 |
+
if isinstance(obj, dict):
|
| 1068 |
+
return {k: convert_to_serializable(v) for k, v in obj.items()}
|
| 1069 |
+
elif isinstance(obj, (np.integer, np.int64, np.int32)):
|
| 1070 |
+
return int(obj)
|
| 1071 |
+
elif isinstance(obj, (np.floating, np.float64, np.float32)):
|
| 1072 |
+
return float(obj)
|
| 1073 |
+
elif isinstance(obj, np.ndarray):
|
| 1074 |
+
return obj.tolist()
|
| 1075 |
+
else:
|
| 1076 |
+
return obj
|
| 1077 |
+
|
| 1078 |
+
metrics_to_save = {
|
| 1079 |
+
'config': {
|
| 1080 |
+
'variant': int(config.variant),
|
| 1081 |
+
'preprocessing': config.preprocessing,
|
| 1082 |
+
'class_scenario': config.class_scenario,
|
| 1083 |
+
'fold_id': int(config.fold_id),
|
| 1084 |
+
'num_classes': int(config.num_classes),
|
| 1085 |
+
'class_names': config.class_names,
|
| 1086 |
+
'architecture_name': config.architecture_name,
|
| 1087 |
+
'model_name': config.model_name,
|
| 1088 |
+
'test_samples': int(test_size)
|
| 1089 |
+
},
|
| 1090 |
+
'metrics': convert_to_serializable(overall_metrics)
|
| 1091 |
+
}
|
| 1092 |
+
|
| 1093 |
+
with open(metrics_file, 'w') as f:
|
| 1094 |
+
json.dump(metrics_to_save, f, indent=2)
|
| 1095 |
+
|
| 1096 |
+
print(f"\n✅ All metrics saved to: {metrics_file}")
|
| 1097 |
+
# print(f"✅ Predictions saved to: {config.predictions_dir}")
|
| 1098 |
+
print(f"✅ Visualizations saved to: {config.visualizations_dir}")
|
| 1099 |
+
|
| 1100 |
+
# Return results
|
| 1101 |
+
return {
|
| 1102 |
+
'patients_results': patient_results,
|
| 1103 |
+
'metrics': overall_metrics,
|
| 1104 |
+
'rich_metrics': overall_metrics_rich
|
| 1105 |
+
}
|
| 1106 |
+
|
| 1107 |
+
|
| 1108 |
+
###################### Main Execution ######################
|
| 1109 |
+
|
| 1110 |
+
if __name__ == "__main__":
|
| 1111 |
+
# Run inference
|
| 1112 |
+
|
| 1113 |
+
preprocess_options = ['standard'] # ['zoomed', 'standard']
|
| 1114 |
+
scenarios = ['3class'] # ['3class', '4class']
|
| 1115 |
+
fold_numbers = list(np.array([0, 1, 2, 3]))
|
| 1116 |
+
|
| 1117 |
+
for fold_number in fold_numbers:
|
| 1118 |
+
for preprocess_option in preprocess_options:
|
| 1119 |
+
for scenario in scenarios:
|
| 1120 |
+
|
| 1121 |
+
config = InferenceConfig(
|
| 1122 |
+
variant=1,
|
| 1123 |
+
preprocessing=preprocess_option,
|
| 1124 |
+
class_scenario=scenario,
|
| 1125 |
+
fold_id=fold_number,
|
| 1126 |
+
model_name='best_dice_model.h5',
|
| 1127 |
+
architecture_name='unet' # a choice from ['unet', 'attnunet', 'dlv3unet', 'transunet']
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
results = run_inference(config)
|
| 1131 |
+
|
| 1132 |
+
# ── Error Analysis ──────────��───────────────────────────
|
| 1133 |
+
error_results = run_error_analysis(
|
| 1134 |
+
results=results,
|
| 1135 |
+
config=config,
|
| 1136 |
+
top_n_slices=300, # visualise N hardest slices
|
| 1137 |
+
top_n_patients=20, # patient summary plots
|
| 1138 |
+
fg_dice_weight=0.7, # tunable ranking weights
|
| 1139 |
+
error_rate_weight=0.2,
|
| 1140 |
+
confidence_weight=0.2,
|
| 1141 |
+
)
|
| 1142 |
+
# ────────────────────────────────────────────────────────
|
| 1143 |
+
|
| 1144 |
+
print("\n" + "="*70)
|
| 1145 |
+
print("INFERENCE + ERROR ANALYSIS COMPLETE")
|
| 1146 |
+
print("="*70)
|
models/for_WMH_Vent/model_training_scripts/p4_run_experiments_all.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 Article - Run Multiple Variant Experiments
|
| 3 |
+
Updated runner script supporting all models
|
| 4 |
+
|
| 5 |
+
Supports:
|
| 6 |
+
- Variant 1: Baseline U-Net
|
| 7 |
+
- Variant 2: Attention U-Net
|
| 8 |
+
- Variant 3: DeepLabV3+ U-Net
|
| 9 |
+
- Variant 4: Trans U-Net
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
# Single experiment
|
| 13 |
+
python p4_run_experiments_all.py --variant 2 --fold 0 --scenario standard_3class
|
| 14 |
+
|
| 15 |
+
# All scenarios for one variant+fold
|
| 16 |
+
python p4_run_experiments_all.py --variant 2 --fold 0
|
| 17 |
+
|
| 18 |
+
# All scenarios for one variant (all folds)
|
| 19 |
+
python p4_run_experiments_all.py --variant 2
|
| 20 |
+
|
| 21 |
+
# All scenarios (all folds and all variants)
|
| 22 |
+
python p4_run_experiments_all.py
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import sys
|
| 26 |
+
import argparse
|
| 27 |
+
import subprocess
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
import tensorflow as tf
|
| 30 |
+
import gc
|
| 31 |
+
from tensorflow.keras import backend as K
|
| 32 |
+
|
| 33 |
+
import p4_unet_viz
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def clear_gpu_memory():
|
| 37 |
+
"""Comprehensive GPU memory cleanup between experiments"""
|
| 38 |
+
print("\n" + "="*70)
|
| 39 |
+
print("CLEANING UP GPU MEMORY")
|
| 40 |
+
print("="*70)
|
| 41 |
+
|
| 42 |
+
# Clear Keras session
|
| 43 |
+
K.clear_session()
|
| 44 |
+
print("✅ Cleared Keras session")
|
| 45 |
+
|
| 46 |
+
# Force garbage collection
|
| 47 |
+
gc.collect()
|
| 48 |
+
print("✅ Ran garbage collection")
|
| 49 |
+
|
| 50 |
+
# Reset TensorFlow graphs
|
| 51 |
+
tf.compat.v1.reset_default_graph()
|
| 52 |
+
print("✅ Reset default graph")
|
| 53 |
+
|
| 54 |
+
# Additional cleanup for TF 2.x
|
| 55 |
+
try:
|
| 56 |
+
# Clear any cached tensors
|
| 57 |
+
tf.config.experimental.reset_memory_stats('GPU:0')
|
| 58 |
+
print("✅ Reset GPU memory stats")
|
| 59 |
+
except:
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
print("="*70 + "\n")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def run_single_experiment(variant: int,
|
| 66 |
+
preprocessing: str,
|
| 67 |
+
class_scenario: str,
|
| 68 |
+
fold_id: int) -> bool:
|
| 69 |
+
"""
|
| 70 |
+
Run a single experiment for specified variant
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)
|
| 74 |
+
preprocessing: 'standard' or 'zoomed'
|
| 75 |
+
class_scenario: '3class' or '4class'
|
| 76 |
+
fold_id: 0-4
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
True if successful, False otherwise
|
| 80 |
+
"""
|
| 81 |
+
print("\n" + "="*80)
|
| 82 |
+
print(f"RUNNING: Variant {variant} | {preprocessing} | {class_scenario} | Fold {fold_id}")
|
| 83 |
+
print("="*80 + "\n")
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
if variant == 1:
|
| 87 |
+
# Baseline unet
|
| 88 |
+
from p4_variant_all_net import ExperimentConfig, train_net
|
| 89 |
+
|
| 90 |
+
config = ExperimentConfig(
|
| 91 |
+
variant=variant,
|
| 92 |
+
preprocessing=preprocessing,
|
| 93 |
+
class_scenario=class_scenario,
|
| 94 |
+
fold_id=fold_id,
|
| 95 |
+
architecture_name='unet'
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
history, history_path = train_net(config)
|
| 99 |
+
p4_unet_viz.main_viz(history_path)
|
| 100 |
+
|
| 101 |
+
# Run Inference
|
| 102 |
+
from p4_inference import InferenceConfig, run_inference, run_error_analysis
|
| 103 |
+
|
| 104 |
+
config = InferenceConfig(
|
| 105 |
+
variant=variant,
|
| 106 |
+
preprocessing=preprocessing,
|
| 107 |
+
class_scenario=class_scenario,
|
| 108 |
+
fold_id=fold_id,
|
| 109 |
+
model_name='best_dice_model.h5',
|
| 110 |
+
architecture_name='unet'
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
results = run_inference(config)
|
| 114 |
+
|
| 115 |
+
# ── Error Analysis ──────────────────────────────────────
|
| 116 |
+
error_results = run_error_analysis(
|
| 117 |
+
results=results,
|
| 118 |
+
config=config,
|
| 119 |
+
top_n_slices=30, # visualise N hardest slices
|
| 120 |
+
top_n_patients=10, # patient summary plots
|
| 121 |
+
fg_dice_weight=0.6, # tunable ranking weights
|
| 122 |
+
error_rate_weight=0.2,
|
| 123 |
+
confidence_weight=0.2,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
elif variant == 2:
|
| 127 |
+
# Attention unet
|
| 128 |
+
from p4_variant_all_net import ExperimentConfig, train_net
|
| 129 |
+
|
| 130 |
+
config = ExperimentConfig(
|
| 131 |
+
variant=variant,
|
| 132 |
+
preprocessing=preprocessing,
|
| 133 |
+
class_scenario=class_scenario,
|
| 134 |
+
fold_id=fold_id,
|
| 135 |
+
architecture_name='attnunet'
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
history, history_path = train_net(config)
|
| 139 |
+
p4_unet_viz.main_viz(history_path)
|
| 140 |
+
|
| 141 |
+
# Run Inference
|
| 142 |
+
from p4_inference import InferenceConfig, run_inference, run_error_analysis
|
| 143 |
+
|
| 144 |
+
config = InferenceConfig(
|
| 145 |
+
variant=variant,
|
| 146 |
+
preprocessing=preprocessing,
|
| 147 |
+
class_scenario=class_scenario,
|
| 148 |
+
fold_id=fold_id,
|
| 149 |
+
model_name='best_dice_model.h5',
|
| 150 |
+
architecture_name='attnunet'
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
results = run_inference(config)
|
| 154 |
+
|
| 155 |
+
# ── Error Analysis ──────────────────────────────────────
|
| 156 |
+
error_results = run_error_analysis(
|
| 157 |
+
results=results,
|
| 158 |
+
config=config,
|
| 159 |
+
top_n_slices=30, # visualise N hardest slices
|
| 160 |
+
top_n_patients=10, # patient summary plots
|
| 161 |
+
fg_dice_weight=0.6, # tunable ranking weights
|
| 162 |
+
error_rate_weight=0.2,
|
| 163 |
+
confidence_weight=0.2,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
elif variant == 3:
|
| 167 |
+
# DeepLabV3+ unet
|
| 168 |
+
from p4_variant_all_net import ExperimentConfig, train_net
|
| 169 |
+
|
| 170 |
+
config = ExperimentConfig(
|
| 171 |
+
variant=variant,
|
| 172 |
+
preprocessing=preprocessing,
|
| 173 |
+
class_scenario=class_scenario,
|
| 174 |
+
fold_id=fold_id,
|
| 175 |
+
architecture_name='dlv3unet'
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
history, history_path = train_net(config)
|
| 179 |
+
p4_unet_viz.main_viz(history_path)
|
| 180 |
+
|
| 181 |
+
# Run Inference
|
| 182 |
+
from p4_inference import InferenceConfig, run_inference, run_error_analysis
|
| 183 |
+
|
| 184 |
+
config = InferenceConfig(
|
| 185 |
+
variant=variant,
|
| 186 |
+
preprocessing=preprocessing,
|
| 187 |
+
class_scenario=class_scenario,
|
| 188 |
+
fold_id=fold_id,
|
| 189 |
+
model_name='best_dice_model.h5',
|
| 190 |
+
architecture_name='dlv3unet'
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
results = run_inference(config)
|
| 194 |
+
|
| 195 |
+
# ── Error Analysis ──────────────────────────────────────
|
| 196 |
+
error_results = run_error_analysis(
|
| 197 |
+
results=results,
|
| 198 |
+
config=config,
|
| 199 |
+
top_n_slices=30, # visualise N hardest slices
|
| 200 |
+
top_n_patients=10, # patient summary plots
|
| 201 |
+
fg_dice_weight=0.6, # tunable ranking weights
|
| 202 |
+
error_rate_weight=0.2,
|
| 203 |
+
confidence_weight=0.2,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
elif variant == 4:
|
| 207 |
+
# Trans unet
|
| 208 |
+
from p4_variant_all_net import ExperimentConfig, train_net
|
| 209 |
+
|
| 210 |
+
config = ExperimentConfig(
|
| 211 |
+
variant=variant,
|
| 212 |
+
preprocessing=preprocessing,
|
| 213 |
+
class_scenario=class_scenario,
|
| 214 |
+
fold_id=fold_id,
|
| 215 |
+
architecture_name='transunet'
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
history, history_path = train_net(config)
|
| 219 |
+
p4_unet_viz.main_viz(history_path)
|
| 220 |
+
|
| 221 |
+
# Run Inference
|
| 222 |
+
from p4_inference import InferenceConfig, run_inference, run_error_analysis
|
| 223 |
+
|
| 224 |
+
config = InferenceConfig(
|
| 225 |
+
variant=variant,
|
| 226 |
+
preprocessing=preprocessing,
|
| 227 |
+
class_scenario=class_scenario,
|
| 228 |
+
fold_id=fold_id,
|
| 229 |
+
model_name='best_dice_model.h5',
|
| 230 |
+
architecture_name='transunet'
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
results = run_inference(config)
|
| 234 |
+
|
| 235 |
+
# ── Error Analysis ──────────────────────────────────────
|
| 236 |
+
error_results = run_error_analysis(
|
| 237 |
+
results=results,
|
| 238 |
+
config=config,
|
| 239 |
+
top_n_slices=30, # visualise N hardest slices
|
| 240 |
+
top_n_patients=10, # patient summary plots
|
| 241 |
+
fg_dice_weight=0.6, # tunable ranking weights
|
| 242 |
+
error_rate_weight=0.2,
|
| 243 |
+
confidence_weight=0.2,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
else:
|
| 247 |
+
raise ValueError(f"Unknown variant: {variant}")
|
| 248 |
+
|
| 249 |
+
print(f"\n✅ Experiment completed successfully!")
|
| 250 |
+
return True
|
| 251 |
+
|
| 252 |
+
except Exception as e:
|
| 253 |
+
print(f"\n❌ Experiment failed with error:")
|
| 254 |
+
print(f" {str(e)}")
|
| 255 |
+
import traceback
|
| 256 |
+
traceback.print_exc()
|
| 257 |
+
return False
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def run_all_scenarios_for_variant_fold(variant: int, fold_id: int) -> dict:
|
| 261 |
+
"""
|
| 262 |
+
Run all 4 scenarios for a given variant and fold
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)
|
| 266 |
+
fold_id: 0-4
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Dictionary with results for each scenario
|
| 270 |
+
"""
|
| 271 |
+
print("\n" + "="*80)
|
| 272 |
+
print(f"RUNNING ALL SCENARIOS FOR VARIANT {variant}, FOLD {fold_id}")
|
| 273 |
+
print("="*80)
|
| 274 |
+
print("\nTotal experiments: 4")
|
| 275 |
+
print(" 1. standard + 3class")
|
| 276 |
+
print(" 2. standard + 4class")
|
| 277 |
+
print(" 3. zoomed + 3class")
|
| 278 |
+
print(" 4. zoomed + 4class")
|
| 279 |
+
print("\n" + "="*80 + "\n")
|
| 280 |
+
|
| 281 |
+
experiments = [
|
| 282 |
+
{'preprocessing': 'zoomed', 'class_scenario': '4class'},
|
| 283 |
+
{'preprocessing': 'standard', 'class_scenario': '4class'},
|
| 284 |
+
{'preprocessing': 'zoomed', 'class_scenario': '3class'},
|
| 285 |
+
{'preprocessing': 'standard', 'class_scenario': '3class'},
|
| 286 |
+
]
|
| 287 |
+
|
| 288 |
+
results = {}
|
| 289 |
+
|
| 290 |
+
for idx, scenario in enumerate(experiments, 1):
|
| 291 |
+
print(f"\n{'#'*80}")
|
| 292 |
+
print(f"SCENARIO {idx}/4: {scenario['preprocessing']} + {scenario['class_scenario']}")
|
| 293 |
+
print(f"{'#'*80}\n")
|
| 294 |
+
|
| 295 |
+
# Run in subprocess for complete memory isolation
|
| 296 |
+
import subprocess
|
| 297 |
+
import sys
|
| 298 |
+
|
| 299 |
+
cmd = [
|
| 300 |
+
sys.executable,
|
| 301 |
+
'p4_run_experiments_all.py',
|
| 302 |
+
'--variant', str(variant),
|
| 303 |
+
'--fold', str(fold_id),
|
| 304 |
+
'--scenario', f"{scenario['preprocessing']}_{scenario['class_scenario']}"
|
| 305 |
+
]
|
| 306 |
+
|
| 307 |
+
print(f"Running command: {' '.join(cmd)}\n")
|
| 308 |
+
|
| 309 |
+
try:
|
| 310 |
+
# Run experiment in separate process
|
| 311 |
+
result = subprocess.run(cmd, check=True, capture_output=False)
|
| 312 |
+
|
| 313 |
+
if result.returncode == 0:
|
| 314 |
+
exp_name = f"v{variant}_{scenario['preprocessing']}_{scenario['class_scenario']}_fold{fold_id}"
|
| 315 |
+
results[exp_name] = {'status': 'SUCCESS'}
|
| 316 |
+
print(f"\n✅ {exp_name} completed successfully")
|
| 317 |
+
else:
|
| 318 |
+
raise Exception(f"Process returned code {result.returncode}")
|
| 319 |
+
|
| 320 |
+
except subprocess.CalledProcessError as e:
|
| 321 |
+
exp_name = f"v{variant}_{scenario['preprocessing']}_{scenario['class_scenario']}_fold{fold_id}"
|
| 322 |
+
print(f"\n❌ Error in {scenario['preprocessing']} + {scenario['class_scenario']}")
|
| 323 |
+
print(f" Error: {str(e)}")
|
| 324 |
+
results[exp_name] = {
|
| 325 |
+
'status': 'FAILED',
|
| 326 |
+
'error': str(e)
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
# Ask user if they want to continue
|
| 330 |
+
response = input("\nContinue with remaining experiments? (y/n): ")
|
| 331 |
+
if response.lower() != 'y':
|
| 332 |
+
print("Stopping experiments...")
|
| 333 |
+
break
|
| 334 |
+
|
| 335 |
+
# Brief pause between experiments
|
| 336 |
+
import time
|
| 337 |
+
print("\n⏳ Waiting 5 seconds before next experiment...")
|
| 338 |
+
time.sleep(5)
|
| 339 |
+
|
| 340 |
+
# Summary
|
| 341 |
+
print("\n" + "="*80)
|
| 342 |
+
print(f"VARIANT {variant}, FOLD {fold_id} - SUMMARY")
|
| 343 |
+
print("="*80)
|
| 344 |
+
|
| 345 |
+
for exp_name, result in results.items():
|
| 346 |
+
status_icon = "✅" if result['status'] == 'SUCCESS' else "❌"
|
| 347 |
+
print(f"{status_icon} {exp_name}")
|
| 348 |
+
|
| 349 |
+
print("\n" + "="*80 + "\n")
|
| 350 |
+
|
| 351 |
+
return results
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def run_all_folds_for_variant(variant: int) -> dict:
|
| 355 |
+
"""
|
| 356 |
+
Run all scenarios for all folds for a given variant
|
| 357 |
+
Run all 4 experiments for all 5 folds
|
| 358 |
+
Total: 4 scenarios × 5 folds = 20 training runs
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
Dictionary with results for all folds
|
| 365 |
+
"""
|
| 366 |
+
print("\n" + "="*80)
|
| 367 |
+
print(f"RUNNING ALL FOLDS FOR VARIANT {variant}")
|
| 368 |
+
print("="*80)
|
| 369 |
+
print("\nTotal experiments: 4 scenarios × 5 folds = 20 training runs")
|
| 370 |
+
print("Estimated time: ~0.7 hour per experiment (with 60 epochs)")
|
| 371 |
+
print("Total estimated time: 10-20 hours")
|
| 372 |
+
print("\n" + "="*80 + "\n")
|
| 373 |
+
|
| 374 |
+
response = input("This will take a long time. Continue? (y/n): ")
|
| 375 |
+
if response.lower() != 'y':
|
| 376 |
+
print("Cancelled.")
|
| 377 |
+
return {}
|
| 378 |
+
|
| 379 |
+
all_results = {}
|
| 380 |
+
|
| 381 |
+
for fold_id in range(5):
|
| 382 |
+
print(f"\n{'='*80}")
|
| 383 |
+
print(f"STARTING FOLD {fold_id}")
|
| 384 |
+
print(f"{'='*80}\n")
|
| 385 |
+
|
| 386 |
+
fold_results = run_all_scenarios_for_variant_fold(variant, fold_id)
|
| 387 |
+
all_results[f'fold_{fold_id}'] = fold_results
|
| 388 |
+
|
| 389 |
+
# Final summary
|
| 390 |
+
print("\n" + "="*80)
|
| 391 |
+
print(f"VARIANT {variant} - ALL FOLDS COMPLETE")
|
| 392 |
+
print("="*80)
|
| 393 |
+
|
| 394 |
+
for fold_id in range(5):
|
| 395 |
+
fold_key = f'fold_{fold_id}'
|
| 396 |
+
if fold_key in all_results:
|
| 397 |
+
print(f"\nFold {fold_id}:")
|
| 398 |
+
for exp_name, result in all_results[fold_key].items():
|
| 399 |
+
status_icon = "✅" if result['status'] == 'SUCCESS' else "❌"
|
| 400 |
+
print(f" {status_icon} {exp_name}")
|
| 401 |
+
|
| 402 |
+
print("\n" + "="*80 + "\n")
|
| 403 |
+
|
| 404 |
+
return all_results
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def compare_variants(fold_id: int = 0):
|
| 408 |
+
"""
|
| 409 |
+
Compare results between baseline and attention variants and newloss variants
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
fold_id: Fold to compare (0-4)
|
| 413 |
+
"""
|
| 414 |
+
print("\n" + "="*80)
|
| 415 |
+
print(f"COMPARING VARIANTS FOR FOLD {fold_id}")
|
| 416 |
+
print("="*80)
|
| 417 |
+
|
| 418 |
+
import json
|
| 419 |
+
|
| 420 |
+
scenarios = [
|
| 421 |
+
{'preprocessing': 'standard', 'class_scenario': '3class'},
|
| 422 |
+
{'preprocessing': 'standard', 'class_scenario': '4class'},
|
| 423 |
+
{'preprocessing': 'zoomed', 'class_scenario': '3class'},
|
| 424 |
+
{'preprocessing': 'zoomed', 'class_scenario': '4class'},
|
| 425 |
+
]
|
| 426 |
+
|
| 427 |
+
results_dir = Path(f"results_fold_{fold_id}")
|
| 428 |
+
|
| 429 |
+
for scenario in scenarios:
|
| 430 |
+
print(f"\n{scenario['preprocessing']} + {scenario['class_scenario']}:")
|
| 431 |
+
print("-" * 60)
|
| 432 |
+
|
| 433 |
+
# Baseline (variant 1)
|
| 434 |
+
baseline_dir = results_dir / "models" / f"{scenario['preprocessing']}_{scenario['class_scenario']}" / f"fold_{fold_id}"
|
| 435 |
+
baseline_history = baseline_dir / "history.json"
|
| 436 |
+
|
| 437 |
+
# Attention (variant 2)
|
| 438 |
+
attention_dir = results_dir / "models" / f"{scenario['preprocessing']}_{scenario['class_scenario']}" / f"fold_{fold_id}_variant2"
|
| 439 |
+
attention_history = attention_dir / "history.json"
|
| 440 |
+
|
| 441 |
+
# Attention (variant 3)
|
| 442 |
+
newloss_dir = results_dir / "models" / f"{scenario['preprocessing']}_{scenario['class_scenario']}" / f"fold_{fold_id}_variant3"
|
| 443 |
+
newloss_history = newloss_dir / "history.json"
|
| 444 |
+
|
| 445 |
+
if baseline_history.exists() and attention_history.exists() and newloss_history.exists():
|
| 446 |
+
with open(baseline_history, 'r') as f:
|
| 447 |
+
baseline_data = json.load(f)
|
| 448 |
+
|
| 449 |
+
with open(attention_history, 'r') as f:
|
| 450 |
+
attention_data = json.load(f)
|
| 451 |
+
|
| 452 |
+
with open(newloss_history, 'r') as f:
|
| 453 |
+
newloss_data = json.load(f)
|
| 454 |
+
|
| 455 |
+
# Compare final validation losses
|
| 456 |
+
baseline_val = baseline_data['val_loss'][-1]
|
| 457 |
+
attention_val = attention_data['val_loss'][-1]
|
| 458 |
+
newloss_val = newloss_data['val_loss'][-1]
|
| 459 |
+
|
| 460 |
+
improvement_1_2 = ((baseline_val - attention_val) / baseline_val) * 100
|
| 461 |
+
improvement_1_3 = ((baseline_val - newloss_val) / baseline_val) * 100
|
| 462 |
+
improvement_2_3 = ((attention_val - newloss_val) / attention_val) * 100
|
| 463 |
+
|
| 464 |
+
print(f" Baseline Val Loss: {baseline_val:.4f}")
|
| 465 |
+
print(f" Attention Val Loss: {attention_val:.4f}")
|
| 466 |
+
print(f" NewLoss Val Loss: {newloss_val:.4f}")
|
| 467 |
+
print(f" Improvement by V2 on V1: {improvement_1_2:+.2f}%")
|
| 468 |
+
print(f" Improvement by V3 on V1: {improvement_1_3:+.2f}%")
|
| 469 |
+
print(f" Improvement by V3 on V2: {improvement_2_3:+.2f}%")
|
| 470 |
+
|
| 471 |
+
else:
|
| 472 |
+
if not baseline_history.exists():
|
| 473 |
+
print(f" ⚠️ Baseline results not found")
|
| 474 |
+
if not attention_history.exists():
|
| 475 |
+
print(f" ⚠️ Attention results not found")
|
| 476 |
+
if not newloss_history.exists():
|
| 477 |
+
print(f" ⚠️ NewLoss results not found")
|
| 478 |
+
|
| 479 |
+
print("\n" + "="*80 + "\n")
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def main():
|
| 483 |
+
"""Main entry point with argument parsing"""
|
| 484 |
+
parser = argparse.ArgumentParser(
|
| 485 |
+
description='Run P4 experiments for multiple variants',
|
| 486 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 487 |
+
epilog="""
|
| 488 |
+
Examples:
|
| 489 |
+
# Single experiment
|
| 490 |
+
python p4_run_experiments_all.py --variant 2 --fold 0 --scenario standard_3class
|
| 491 |
+
|
| 492 |
+
# All scenarios for variant 2, fold 0
|
| 493 |
+
python p4_run_experiments_all.py --variant 2 --fold 0
|
| 494 |
+
|
| 495 |
+
# All folds for variant 3
|
| 496 |
+
python p4_run_experiments_all.py --variant 2
|
| 497 |
+
|
| 498 |
+
# Compare results
|
| 499 |
+
python p4_run_experiments_all.py --compare --fold 0
|
| 500 |
+
"""
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
parser.add_argument(
|
| 504 |
+
'--variant',
|
| 505 |
+
type=int,
|
| 506 |
+
choices=[1, 2, 3, 4],
|
| 507 |
+
help='variant: 1 (baseline u-net) or 2 (attention u-net) or 3 (deeplabv3+ u-net) or 4 (trans u-net)'
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
parser.add_argument(
|
| 511 |
+
'--fold',
|
| 512 |
+
type=int,
|
| 513 |
+
choices=[0, 1, 2, 3, 4],
|
| 514 |
+
help='Specific fold to train (0-4)'
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
parser.add_argument(
|
| 518 |
+
'--scenario',
|
| 519 |
+
type=str,
|
| 520 |
+
choices=['standard_3class', 'standard_4class', 'zoomed_3class', 'zoomed_4class'],
|
| 521 |
+
help='Specific scenario to train'
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
parser.add_argument(
|
| 525 |
+
'--compare',
|
| 526 |
+
action='store_true',
|
| 527 |
+
help='Compare results between variants'
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
args = parser.parse_args()
|
| 531 |
+
|
| 532 |
+
# Handle comparison mode (NOT READY YET!)
|
| 533 |
+
if args.compare:
|
| 534 |
+
fold_id = args.fold if args.fold is not None else 0
|
| 535 |
+
compare_variants(fold_id)
|
| 536 |
+
return
|
| 537 |
+
|
| 538 |
+
# Validate arguments
|
| 539 |
+
if args.variant is None:
|
| 540 |
+
parser.error("--variant is required (unless using --compare)")
|
| 541 |
+
|
| 542 |
+
# Single experiment
|
| 543 |
+
if args.scenario is not None:
|
| 544 |
+
preprocessing, class_scenario = args.scenario.split('_')
|
| 545 |
+
fold_id = args.fold if args.fold is not None else 0
|
| 546 |
+
|
| 547 |
+
print(f"\nRunning single experiment:")
|
| 548 |
+
print(f" Variant: {args.variant}")
|
| 549 |
+
print(f" Fold: {fold_id}")
|
| 550 |
+
print(f" Preprocessing: {preprocessing}")
|
| 551 |
+
print(f" Class scenario: {class_scenario}\n")
|
| 552 |
+
|
| 553 |
+
success = run_single_experiment(
|
| 554 |
+
variant=args.variant,
|
| 555 |
+
preprocessing=preprocessing,
|
| 556 |
+
class_scenario=class_scenario,
|
| 557 |
+
fold_id=fold_id
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
if success:
|
| 561 |
+
print("\n✅ Experiment complete!")
|
| 562 |
+
else:
|
| 563 |
+
print("\n❌ Experiment failed!")
|
| 564 |
+
sys.exit(1)
|
| 565 |
+
|
| 566 |
+
# All scenarios for specific fold
|
| 567 |
+
elif args.fold is not None:
|
| 568 |
+
run_all_scenarios_for_variant_fold(args.variant, args.fold)
|
| 569 |
+
|
| 570 |
+
# All scenarios for all folds
|
| 571 |
+
else:
|
| 572 |
+
run_all_folds_for_variant(args.variant)
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
if __name__ == "__main__":
|
| 576 |
+
main()
|
models/for_WMH_Vent/model_training_scripts/p4_unet_viz.py
ADDED
|
@@ -0,0 +1,640 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 - All U-Net models with Adaptive Loss (WCE + UFL)
|
| 3 |
+
|
| 4 |
+
WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
|
| 5 |
+
Three-class segmentation: Background vs Ventricles vs Abnormal WMH
|
| 6 |
+
Professional results saving and visualization for publication
|
| 7 |
+
|
| 8 |
+
This relates to our article:
|
| 9 |
+
"Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
|
| 10 |
+
A Large-Scale Multiple Sclerosis Study"
|
| 11 |
+
|
| 12 |
+
Features:
|
| 13 |
+
- Visualization of Results
|
| 14 |
+
|
| 15 |
+
Authors:
|
| 16 |
+
"Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
|
| 17 |
+
|
| 18 |
+
Developer:
|
| 19 |
+
"Mahdi Bashiri Bawil"
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import json
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
import numpy as np
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_history(filepath):
|
| 30 |
+
"""Load training history from JSON file."""
|
| 31 |
+
with open(filepath, 'r') as f:
|
| 32 |
+
return json.load(f)
|
| 33 |
+
|
| 34 |
+
def detect_num_classes(history):
|
| 35 |
+
"""Detect number of classes from val_metrics."""
|
| 36 |
+
if not history['val_metrics']:
|
| 37 |
+
return 3
|
| 38 |
+
first_metric = history['val_metrics'][0]
|
| 39 |
+
# Count only class_X keys, not 'mean'
|
| 40 |
+
num_classes = len([k for k in first_metric['dice'].keys() if k.startswith('class_')])
|
| 41 |
+
return num_classes
|
| 42 |
+
|
| 43 |
+
def get_class_names(num_classes):
|
| 44 |
+
"""Get class names based on number of classes."""
|
| 45 |
+
if num_classes == 3:
|
| 46 |
+
return {
|
| 47 |
+
'class_0': 'Background',
|
| 48 |
+
'class_1': 'Ventricles',
|
| 49 |
+
'class_2': 'Abnormal WMH'
|
| 50 |
+
}
|
| 51 |
+
elif num_classes == 4:
|
| 52 |
+
return {
|
| 53 |
+
'class_0': 'Background',
|
| 54 |
+
'class_1': 'Ventricles',
|
| 55 |
+
'class_2': 'Normal WMH',
|
| 56 |
+
'class_3': 'Abnormal WMH'
|
| 57 |
+
}
|
| 58 |
+
else:
|
| 59 |
+
return {f'class_{i}': f'Class {i}' for i in range(num_classes)}
|
| 60 |
+
|
| 61 |
+
def convert_to_native_types(obj):
|
| 62 |
+
"""Recursively convert numpy types to native Python types for JSON serialization."""
|
| 63 |
+
if isinstance(obj, np.integer):
|
| 64 |
+
return int(obj)
|
| 65 |
+
elif isinstance(obj, np.floating):
|
| 66 |
+
return float(obj)
|
| 67 |
+
elif isinstance(obj, np.ndarray):
|
| 68 |
+
return obj.tolist()
|
| 69 |
+
elif isinstance(obj, dict):
|
| 70 |
+
return {key: convert_to_native_types(value) for key, value in obj.items()}
|
| 71 |
+
elif isinstance(obj, list):
|
| 72 |
+
return [convert_to_native_types(item) for item in obj]
|
| 73 |
+
else:
|
| 74 |
+
return obj
|
| 75 |
+
|
| 76 |
+
def find_best_epoch(history, num_classes):
|
| 77 |
+
"""
|
| 78 |
+
Find the best epoch based on prioritized criteria:
|
| 79 |
+
1. Highest Dice for abnormal WMH (top priority)
|
| 80 |
+
2. Highest Dice for ventricles (secondary)
|
| 81 |
+
3. Lowest validation loss (tertiary)
|
| 82 |
+
4. ONLY consider epochs where beta > 0.95 (CRITICAL REQUIREMENT)
|
| 83 |
+
|
| 84 |
+
"""
|
| 85 |
+
if not history['val_metrics']:
|
| 86 |
+
return None, {}
|
| 87 |
+
|
| 88 |
+
epochs = range(1, len(history['val_metrics']) + 1)
|
| 89 |
+
if 'beta_value' in history:
|
| 90 |
+
beta_values = history['beta_value']
|
| 91 |
+
else:
|
| 92 |
+
beta_values = [1] * len(history.get('val_loss', []))
|
| 93 |
+
history['beta_value'] = beta_values
|
| 94 |
+
|
| 95 |
+
# Find epochs where beta > 0.95 (CRITICAL FILTER)
|
| 96 |
+
valid_epoch_indices = [i for i, beta in enumerate(beta_values) if beta > 0.95]
|
| 97 |
+
|
| 98 |
+
if not valid_epoch_indices:
|
| 99 |
+
print("⚠️ WARNING: No epochs found with beta > 0.95!")
|
| 100 |
+
print(" Using all epochs for analysis (not recommended).")
|
| 101 |
+
valid_epoch_indices = list(range(len(beta_values)))
|
| 102 |
+
|
| 103 |
+
first_valid_epoch = valid_epoch_indices[0] + 1 if valid_epoch_indices else 1
|
| 104 |
+
|
| 105 |
+
# Determine the key for abnormal WMH
|
| 106 |
+
abnormal_key = 'class_3' if num_classes == 4 else 'class_2'
|
| 107 |
+
ventricles_key = 'class_1'
|
| 108 |
+
|
| 109 |
+
# Extract metrics
|
| 110 |
+
abnormal_dice = [m['dice'][abnormal_key] for m in history['val_metrics']]
|
| 111 |
+
ventricles_dice = [m['dice'][ventricles_key] for m in history['val_metrics']]
|
| 112 |
+
val_losses = history['val_loss']
|
| 113 |
+
|
| 114 |
+
# Find best epoch for abnormal WMH dice (only among valid epochs)
|
| 115 |
+
valid_abnormal_dice = [(i, abnormal_dice[i]) for i in valid_epoch_indices]
|
| 116 |
+
best_abnormal_idx = max(valid_abnormal_dice, key=lambda x: x[1])[0]
|
| 117 |
+
best_abnormal_epoch = best_abnormal_idx + 1
|
| 118 |
+
best_abnormal_dice = abnormal_dice[best_abnormal_idx]
|
| 119 |
+
|
| 120 |
+
# Find best epoch for ventricles dice (only among valid epochs)
|
| 121 |
+
valid_ventricles_dice = [(i, ventricles_dice[i]) for i in valid_epoch_indices]
|
| 122 |
+
best_ventricles_idx = max(valid_ventricles_dice, key=lambda x: x[1])[0]
|
| 123 |
+
best_ventricles_epoch = best_ventricles_idx + 1
|
| 124 |
+
best_ventricles_dice = ventricles_dice[best_ventricles_idx]
|
| 125 |
+
|
| 126 |
+
# Find best epoch for validation loss (only among valid epochs)
|
| 127 |
+
valid_val_losses = [(i, val_losses[i]) for i in valid_epoch_indices]
|
| 128 |
+
best_val_loss_idx = min(valid_val_losses, key=lambda x: x[1])[0]
|
| 129 |
+
best_val_loss_epoch = best_val_loss_idx + 1
|
| 130 |
+
best_val_loss = val_losses[best_val_loss_idx]
|
| 131 |
+
|
| 132 |
+
# Calculate composite score (weighted) - ONLY for valid epochs
|
| 133 |
+
composite_scores = [float('-inf')] * len(abnormal_dice)
|
| 134 |
+
|
| 135 |
+
for i in valid_epoch_indices:
|
| 136 |
+
# Normalize and weight: 60% abnormal dice, 30% ventricles dice, 10% inv val_loss
|
| 137 |
+
norm_abnormal = abnormal_dice[i]
|
| 138 |
+
norm_ventricles = ventricles_dice[i]
|
| 139 |
+
|
| 140 |
+
# Normalize validation loss among valid epochs only
|
| 141 |
+
valid_val_loss_values = [val_losses[j] for j in valid_epoch_indices]
|
| 142 |
+
max_val_loss = max(valid_val_loss_values) if valid_val_loss_values else 1
|
| 143 |
+
norm_val_loss = 1 - (val_losses[i] / max_val_loss) if max_val_loss > 0 else 0
|
| 144 |
+
|
| 145 |
+
composite = 0.6 * norm_abnormal + 0.3 * norm_ventricles + 0.1 * (1 - val_losses[i]) # norm_val_loss
|
| 146 |
+
composite_scores[i] = composite
|
| 147 |
+
|
| 148 |
+
best_overall_idx = int(np.argmax(composite_scores)) # Convert to int
|
| 149 |
+
best_overall_epoch = best_overall_idx + 1
|
| 150 |
+
|
| 151 |
+
# Get all metrics at best epoch
|
| 152 |
+
best_epoch_metrics = history['val_metrics'][best_overall_idx]
|
| 153 |
+
|
| 154 |
+
analysis = {
|
| 155 |
+
'best_overall_epoch': int(best_overall_epoch),
|
| 156 |
+
'best_overall_epoch_idx': int(best_overall_idx),
|
| 157 |
+
'best_abnormal_epoch': int(best_abnormal_epoch),
|
| 158 |
+
'best_abnormal_dice': float(best_abnormal_dice),
|
| 159 |
+
'best_ventricles_epoch': int(best_ventricles_epoch),
|
| 160 |
+
'best_ventricles_dice': float(best_ventricles_dice),
|
| 161 |
+
'best_val_loss_epoch': int(best_val_loss_epoch),
|
| 162 |
+
'best_val_loss': float(best_val_loss),
|
| 163 |
+
'composite_score': float(composite_scores[best_overall_idx]),
|
| 164 |
+
'abnormal_key': abnormal_key,
|
| 165 |
+
'num_classes': int(num_classes),
|
| 166 |
+
'first_valid_epoch': int(first_valid_epoch),
|
| 167 |
+
'total_valid_epochs': int(len(valid_epoch_indices)),
|
| 168 |
+
'beta_threshold': 0.95,
|
| 169 |
+
'total_epochs': int(len(epochs)),
|
| 170 |
+
# Add complete metrics at best epoch
|
| 171 |
+
'best_epoch_metrics': {
|
| 172 |
+
'dice': best_epoch_metrics['dice'],
|
| 173 |
+
'precision': best_epoch_metrics['precision'],
|
| 174 |
+
'recall': best_epoch_metrics['recall'],
|
| 175 |
+
'val_loss': float(val_losses[best_overall_idx]),
|
| 176 |
+
'train_loss': float(history['train_loss'][best_overall_idx]),
|
| 177 |
+
'wce_loss': float(history['wce_loss'][best_overall_idx]),
|
| 178 |
+
'ufd_loss': float(history['ufd_loss'][best_overall_idx]),
|
| 179 |
+
'val_loss_wce': float(history['val_loss_wce'][best_overall_idx]) if 'val_loss_wce' in history else None,
|
| 180 |
+
'val_loss_ufd': float(history['val_loss_ufd'][best_overall_idx]) if 'val_loss_ufd' in history else None,
|
| 181 |
+
'beta_value': float(beta_values[best_overall_idx])
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
# Convert all numpy types to native Python types
|
| 186 |
+
analysis = convert_to_native_types(analysis)
|
| 187 |
+
|
| 188 |
+
return best_overall_epoch, analysis
|
| 189 |
+
|
| 190 |
+
def save_analysis_json(analysis, output_path):
|
| 191 |
+
"""Save analysis results to a JSON file."""
|
| 192 |
+
analysis = convert_to_native_types(analysis)
|
| 193 |
+
with open(output_path, 'w') as f:
|
| 194 |
+
json.dump(analysis, f, indent=2)
|
| 195 |
+
print(f"✓ Analysis saved to: {output_path}")
|
| 196 |
+
|
| 197 |
+
def save_enhanced_history(history, analysis, output_path):
|
| 198 |
+
"""Save enhanced history with best epoch analysis appended."""
|
| 199 |
+
enhanced_history = history.copy()
|
| 200 |
+
enhanced_history['best_epoch_analysis'] = convert_to_native_types(analysis)
|
| 201 |
+
enhanced_history = convert_to_native_types(enhanced_history)
|
| 202 |
+
|
| 203 |
+
with open(output_path, 'w') as f:
|
| 204 |
+
json.dump(enhanced_history, f, indent=2)
|
| 205 |
+
print(f"✓ Enhanced history saved to: {output_path}")
|
| 206 |
+
|
| 207 |
+
def create_training_summary(history, analysis, class_names):
|
| 208 |
+
"""Create a comprehensive training summary for easy parsing."""
|
| 209 |
+
summary = {
|
| 210 |
+
'training_config': {
|
| 211 |
+
'total_epochs': analysis['total_epochs'],
|
| 212 |
+
'num_classes': analysis['num_classes'],
|
| 213 |
+
'class_names': class_names,
|
| 214 |
+
'model_type': 'a U-Net'
|
| 215 |
+
},
|
| 216 |
+
'best_epoch_selection': {
|
| 217 |
+
'overall_best_epoch': analysis['best_overall_epoch'],
|
| 218 |
+
'composite_score': analysis['composite_score'],
|
| 219 |
+
'selection_criteria': {
|
| 220 |
+
'abnormal_wmh_weight': 0.6,
|
| 221 |
+
'ventricles_weight': 0.3,
|
| 222 |
+
'val_loss_weight': 0.1
|
| 223 |
+
}
|
| 224 |
+
},
|
| 225 |
+
'priority_metrics': {
|
| 226 |
+
'abnormal_wmh': {
|
| 227 |
+
'best_epoch': analysis['best_abnormal_epoch'],
|
| 228 |
+
'best_dice': analysis['best_abnormal_dice']
|
| 229 |
+
},
|
| 230 |
+
'ventricles': {
|
| 231 |
+
'best_epoch': analysis['best_ventricles_epoch'],
|
| 232 |
+
'best_dice': analysis['best_ventricles_dice']
|
| 233 |
+
},
|
| 234 |
+
'validation_loss': {
|
| 235 |
+
'best_epoch': analysis['best_val_loss_epoch'],
|
| 236 |
+
'best_loss': analysis['best_val_loss']
|
| 237 |
+
}
|
| 238 |
+
},
|
| 239 |
+
'best_epoch_metrics': analysis['best_epoch_metrics'],
|
| 240 |
+
'training_progression': {
|
| 241 |
+
'final_epoch_metrics': {
|
| 242 |
+
'dice': history['val_metrics'][-1]['dice'],
|
| 243 |
+
'precision': history['val_metrics'][-1]['precision'],
|
| 244 |
+
'recall': history['val_metrics'][-1]['recall'],
|
| 245 |
+
'val_loss': history['val_loss'][-1],
|
| 246 |
+
'train_loss': history['train_loss'][-1]
|
| 247 |
+
},
|
| 248 |
+
'convergence_info': {
|
| 249 |
+
'epochs_trained': len(history['val_loss'])
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
# Add epoch-by-epoch metrics for important classes
|
| 255 |
+
summary['epoch_progression'] = {
|
| 256 |
+
'abnormal_wmh_dice': [m['dice'][analysis['abnormal_key']] for m in history['val_metrics']],
|
| 257 |
+
'ventricles_dice': [m['dice']['class_1'] for m in history['val_metrics']],
|
| 258 |
+
'mean_dice': [m['dice']['mean'] for m in history['val_metrics']],
|
| 259 |
+
'val_loss': history['val_loss'],
|
| 260 |
+
'train_loss': history['train_loss']
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
summary = convert_to_native_types(summary)
|
| 264 |
+
|
| 265 |
+
return summary
|
| 266 |
+
|
| 267 |
+
def plot_training_history(history, save_path='training_history.png'):
|
| 268 |
+
"""Create comprehensive visualization of training history."""
|
| 269 |
+
|
| 270 |
+
num_classes = detect_num_classes(history)
|
| 271 |
+
class_names = get_class_names(num_classes)
|
| 272 |
+
best_epoch, analysis = find_best_epoch(history, num_classes)
|
| 273 |
+
|
| 274 |
+
epochs = range(1, len(history['train_loss']) + 1)
|
| 275 |
+
|
| 276 |
+
# Detect whether new-style history (with val_loss_wce / val_loss_ufd) is present
|
| 277 |
+
has_val_components = 'val_loss_wce' in history and 'val_loss_ufd' in history
|
| 278 |
+
|
| 279 |
+
# Create figure — 3 rows × 3 cols when val components exist, else 2×3
|
| 280 |
+
nrows = 3 if has_val_components else 2
|
| 281 |
+
fig = plt.figure(figsize=(18, nrows * 5))
|
| 282 |
+
gs = fig.add_gridspec(nrows, 3, hspace=0.35, wspace=0.3)
|
| 283 |
+
|
| 284 |
+
# Color scheme
|
| 285 |
+
colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D']
|
| 286 |
+
wce_color = '#4CAF50' # green – WCE
|
| 287 |
+
ufd_color = '#9C27B0' # purple – UFD
|
| 288 |
+
beta_color = '#FF5722' # deep-orange – beta
|
| 289 |
+
|
| 290 |
+
# 1. Training and Validation Loss (combined / weighted)
|
| 291 |
+
ax1 = fig.add_subplot(gs[0, 0])
|
| 292 |
+
ax1.plot(epochs, history['train_loss'], 'o-', linewidth=2, markersize=6,
|
| 293 |
+
color=colors[0], label='Train Loss')
|
| 294 |
+
ax1.plot(epochs, history['val_loss'], 's-', linewidth=2, markersize=6,
|
| 295 |
+
color=colors[2], label='Val Loss')
|
| 296 |
+
if best_epoch:
|
| 297 |
+
ax1.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2,
|
| 298 |
+
alpha=0.7, label=f'Best Epoch ({best_epoch})')
|
| 299 |
+
ax1.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 300 |
+
ax1.set_ylabel('Loss', fontsize=11, fontweight='bold')
|
| 301 |
+
ax1.set_title('Training & Validation Loss\n(Combined Adaptive Loss)', fontsize=13, fontweight='bold')
|
| 302 |
+
ax1.legend(fontsize=9)
|
| 303 |
+
ax1.grid(True, alpha=0.3)
|
| 304 |
+
|
| 305 |
+
# 2. Dice Scores (excluding background)
|
| 306 |
+
ax2 = fig.add_subplot(gs[0, 1])
|
| 307 |
+
for i in range(1, num_classes): # Skip class_0 (background)
|
| 308 |
+
class_key = f'class_{i}'
|
| 309 |
+
dice_scores = [m['dice'][class_key] for m in history['val_metrics']]
|
| 310 |
+
ax2.plot(epochs, dice_scores, 'o-', linewidth=2, markersize=6,
|
| 311 |
+
label=class_names[class_key], color=colors[i % len(colors)])
|
| 312 |
+
|
| 313 |
+
if best_epoch:
|
| 314 |
+
ax2.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2,
|
| 315 |
+
alpha=0.7, label=f'Best Epoch ({best_epoch})')
|
| 316 |
+
ax2.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 317 |
+
ax2.set_ylabel('Dice Score', fontsize=11, fontweight='bold')
|
| 318 |
+
ax2.set_title('Dice Scores by Class', fontsize=13, fontweight='bold')
|
| 319 |
+
ax2.legend(fontsize=9)
|
| 320 |
+
ax2.grid(True, alpha=0.3)
|
| 321 |
+
ax2.set_ylim([0, 1])
|
| 322 |
+
|
| 323 |
+
# 3. Precision Scores (excluding background)
|
| 324 |
+
ax3 = fig.add_subplot(gs[0, 2])
|
| 325 |
+
for i in range(1, num_classes):
|
| 326 |
+
class_key = f'class_{i}'
|
| 327 |
+
precision_scores = [m['precision'][class_key] for m in history['val_metrics']]
|
| 328 |
+
ax3.plot(epochs, precision_scores, 's-', linewidth=2, markersize=5,
|
| 329 |
+
label=class_names[class_key], color=colors[i % len(colors)])
|
| 330 |
+
|
| 331 |
+
if best_epoch:
|
| 332 |
+
ax3.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2, alpha=0.7)
|
| 333 |
+
ax3.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 334 |
+
ax3.set_ylabel('Precision', fontsize=11, fontweight='bold')
|
| 335 |
+
ax3.set_title('Precision by Class', fontsize=13, fontweight='bold')
|
| 336 |
+
ax3.legend(fontsize=9)
|
| 337 |
+
ax3.grid(True, alpha=0.3)
|
| 338 |
+
ax3.set_ylim([0, 1])
|
| 339 |
+
|
| 340 |
+
# 4. Recall Scores (excluding background)
|
| 341 |
+
ax4 = fig.add_subplot(gs[1, 0])
|
| 342 |
+
for i in range(1, num_classes):
|
| 343 |
+
class_key = f'class_{i}'
|
| 344 |
+
recall_scores = [m['recall'][class_key] for m in history['val_metrics']]
|
| 345 |
+
ax4.plot(epochs, recall_scores, '^-', linewidth=2, markersize=5,
|
| 346 |
+
label=class_names[class_key], color=colors[i % len(colors)])
|
| 347 |
+
|
| 348 |
+
if best_epoch:
|
| 349 |
+
ax4.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2, alpha=0.7)
|
| 350 |
+
ax4.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 351 |
+
ax4.set_ylabel('Recall', fontsize=11, fontweight='bold')
|
| 352 |
+
ax4.set_title('Recall by Class', fontsize=13, fontweight='bold')
|
| 353 |
+
ax4.legend(fontsize=9)
|
| 354 |
+
ax4.grid(True, alpha=0.3)
|
| 355 |
+
ax4.set_ylim([0, 1])
|
| 356 |
+
|
| 357 |
+
# 5. Mean Metrics
|
| 358 |
+
ax5 = fig.add_subplot(gs[1, 1])
|
| 359 |
+
mean_dice = [m['dice']['mean'] for m in history['val_metrics']]
|
| 360 |
+
mean_precision = [m['precision']['mean'] for m in history['val_metrics']]
|
| 361 |
+
mean_recall = [m['recall']['mean'] for m in history['val_metrics']]
|
| 362 |
+
|
| 363 |
+
ax5.plot(epochs, mean_dice, 'o-', linewidth=2, markersize=6,
|
| 364 |
+
color=colors[0], label='Mean Dice')
|
| 365 |
+
ax5.plot(epochs, mean_precision, 's-', linewidth=2, markersize=5,
|
| 366 |
+
color=colors[1], label='Mean Precision')
|
| 367 |
+
ax5.plot(epochs, mean_recall, '^-', linewidth=2, markersize=5,
|
| 368 |
+
color=colors[2], label='Mean Recall')
|
| 369 |
+
|
| 370 |
+
if best_epoch:
|
| 371 |
+
ax5.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2, alpha=0.7)
|
| 372 |
+
ax5.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 373 |
+
ax5.set_ylabel('Score', fontsize=11, fontweight='bold')
|
| 374 |
+
ax5.set_title('Mean Validation Metrics', fontsize=13, fontweight='bold')
|
| 375 |
+
ax5.legend(fontsize=9)
|
| 376 |
+
ax5.grid(True, alpha=0.3)
|
| 377 |
+
ax5.set_ylim([0, 1])
|
| 378 |
+
|
| 379 |
+
# ── New Row 3 plots (only when val components are available) ──────────────
|
| 380 |
+
if has_val_components:
|
| 381 |
+
# 7. Training Loss Components (WCE vs UFD, train-side)
|
| 382 |
+
ax7 = fig.add_subplot(gs[2, 0])
|
| 383 |
+
ax7.plot(epochs, list(1*np.array(history['wce_loss'])), 'o-', linewidth=2, markersize=5,
|
| 384 |
+
color=wce_color, label='Train WCE Loss x10')
|
| 385 |
+
ax7.plot(epochs, history['ufd_loss'], 's-', linewidth=2, markersize=5,
|
| 386 |
+
color=ufd_color, label='Train UFD Loss')
|
| 387 |
+
ax7.plot(epochs, list(1*np.array(history['val_loss_wce'])), 'o--', linewidth=1.5, markersize=4,
|
| 388 |
+
color=wce_color, alpha=0.6, label='Val WCE Loss x10')
|
| 389 |
+
ax7.plot(epochs, history['val_loss_ufd'], 's--', linewidth=1.5, markersize=4,
|
| 390 |
+
color=ufd_color, alpha=0.6, label='Val UFD Loss')
|
| 391 |
+
if best_epoch:
|
| 392 |
+
ax7.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2,
|
| 393 |
+
alpha=0.7, label=f'Best Epoch ({best_epoch})')
|
| 394 |
+
ax7.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 395 |
+
ax7.set_ylabel('Loss', fontsize=11, fontweight='bold')
|
| 396 |
+
ax7.set_title('Loss Components: WCE vs UFD\n(Train solid · Val dashed)', fontsize=13, fontweight='bold')
|
| 397 |
+
ax7.legend(fontsize=8)
|
| 398 |
+
ax7.grid(True, alpha=0.3)
|
| 399 |
+
|
| 400 |
+
# 8. Weighted contribution of each loss to the total loss
|
| 401 |
+
ax8 = fig.add_subplot(gs[2, 1])
|
| 402 |
+
beta_values = history.get('beta_value', [e / len(epochs) for e in epochs])
|
| 403 |
+
betas = np.array(beta_values)
|
| 404 |
+
ones = np.ones_like(betas)
|
| 405 |
+
|
| 406 |
+
# Weighted contributions
|
| 407 |
+
train_wce_contrib = (ones - betas) * np.array(history['wce_loss'])
|
| 408 |
+
train_ufd_contrib = betas * np.array(history['ufd_loss'])
|
| 409 |
+
val_wce_contrib = (ones - betas) * np.array(history['val_loss_wce'])
|
| 410 |
+
val_ufd_contrib = betas * np.array(history['val_loss_ufd'])
|
| 411 |
+
|
| 412 |
+
ax8.stackplot(list(epochs),
|
| 413 |
+
train_wce_contrib, train_ufd_contrib,
|
| 414 |
+
labels=['(1−β)·WCE [train] x10', 'β·UFD [train]'],
|
| 415 |
+
colors=[wce_color, ufd_color], alpha=0.55)
|
| 416 |
+
ax8.plot(epochs, history['train_loss'], 'k-', linewidth=1.5, label='Total Train Loss')
|
| 417 |
+
|
| 418 |
+
# Overlay val contributions as lines for clarity
|
| 419 |
+
ax8.plot(epochs, val_wce_contrib, '--', color=wce_color, linewidth=1.5,
|
| 420 |
+
alpha=0.8, label='(1−β)·WCE [val] x10')
|
| 421 |
+
ax8.plot(epochs, val_ufd_contrib, '--', color=ufd_color, linewidth=1.5,
|
| 422 |
+
alpha=0.8, label='β·UFD [val]')
|
| 423 |
+
if best_epoch:
|
| 424 |
+
ax8.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2, alpha=0.7)
|
| 425 |
+
ax8.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 426 |
+
ax8.set_ylabel('Weighted Loss', fontsize=11, fontweight='bold')
|
| 427 |
+
ax8.set_title('Weighted Loss Contributions\n(Adaptive β Schedule)', fontsize=13, fontweight='bold')
|
| 428 |
+
ax8.legend(fontsize=8)
|
| 429 |
+
ax8.grid(True, alpha=0.3)
|
| 430 |
+
|
| 431 |
+
# # 9. Beta schedule
|
| 432 |
+
# ax9 = fig.add_subplot(gs[2, 2])
|
| 433 |
+
# ax9.plot(list(epochs), betas, 'o-', linewidth=2, markersize=5,
|
| 434 |
+
# color=beta_color, label='β (epoch/total)')
|
| 435 |
+
# ax9.fill_between(list(epochs), betas, alpha=0.15, color=beta_color)
|
| 436 |
+
# ax9.axhline(y=0.95, color='gray', linestyle=':', linewidth=1.5,
|
| 437 |
+
# label='β = 0.95 threshold')
|
| 438 |
+
# if best_epoch:
|
| 439 |
+
# ax9.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2,
|
| 440 |
+
# alpha=0.7, label=f'Best Epoch ({best_epoch})')
|
| 441 |
+
# ax9.set_xlabel('Epoch', fontsize=11, fontweight='bold')
|
| 442 |
+
# ax9.set_ylabel('β value', fontsize=11, fontweight='bold')
|
| 443 |
+
# ax9.set_title('Beta Schedule\n(WCE → UFD transition)', fontsize=13, fontweight='bold')
|
| 444 |
+
# ax9.set_ylim([0, 1.05])
|
| 445 |
+
# ax9.legend(fontsize=9)
|
| 446 |
+
# ax9.grid(True, alpha=0.3)
|
| 447 |
+
|
| 448 |
+
# 6. Analysis Summary
|
| 449 |
+
ax6 = fig.add_subplot(gs[1, 2])
|
| 450 |
+
ax6.axis('off')
|
| 451 |
+
|
| 452 |
+
if analysis:
|
| 453 |
+
abnormal_class = class_names[analysis['abnormal_key']]
|
| 454 |
+
best_epoch_idx = analysis['best_overall_epoch'] - 1
|
| 455 |
+
|
| 456 |
+
# Get dice scores for all classes at the best epoch
|
| 457 |
+
best_epoch_metrics = history['val_metrics'][best_epoch_idx]['dice']
|
| 458 |
+
|
| 459 |
+
# Build dice scores text (excluding background)
|
| 460 |
+
dice_scores_text = ""
|
| 461 |
+
for i in range(1, num_classes):
|
| 462 |
+
class_key = f'class_{i}'
|
| 463 |
+
dice_value = best_epoch_metrics[class_key]
|
| 464 |
+
dice_scores_text += f" {class_names[class_key]}: {dice_value:.4f}\n"
|
| 465 |
+
|
| 466 |
+
summary_text = f"""
|
| 467 |
+
TRAINING ANALYSIS SUMMARY
|
| 468 |
+
{'=' * 40}
|
| 469 |
+
|
| 470 |
+
Model: a U-Net
|
| 471 |
+
Number of Classes: {analysis['num_classes']}
|
| 472 |
+
Total Epochs: {len(epochs)}
|
| 473 |
+
|
| 474 |
+
BEST OVERALL EPOCH: {analysis['best_overall_epoch']}
|
| 475 |
+
(Composite Score: {analysis['composite_score']:.4f})
|
| 476 |
+
|
| 477 |
+
Dice Scores at Best Epoch:
|
| 478 |
+
{dice_scores_text}
|
| 479 |
+
{'─' * 40}
|
| 480 |
+
Priority Metrics:
|
| 481 |
+
{'─' * 40}
|
| 482 |
+
|
| 483 |
+
Best {abnormal_class} Dice:
|
| 484 |
+
Epoch {analysis['best_abnormal_epoch']}: {analysis['best_abnormal_dice']:.4f}
|
| 485 |
+
|
| 486 |
+
Best Ventricles Dice:
|
| 487 |
+
Epoch {analysis['best_ventricles_epoch']}: {analysis['best_ventricles_dice']:.4f}
|
| 488 |
+
|
| 489 |
+
Best Validation Loss:
|
| 490 |
+
Epoch {analysis['best_val_loss_epoch']}: {analysis['best_val_loss']:.4f}
|
| 491 |
+
|
| 492 |
+
{'─' * 40}
|
| 493 |
+
Loss at Best Epoch:
|
| 494 |
+
Train WCE: {analysis['best_epoch_metrics']['wce_loss']:.4f}
|
| 495 |
+
Train UFD: {analysis['best_epoch_metrics']['ufd_loss']:.4f}"""
|
| 496 |
+
|
| 497 |
+
if analysis['best_epoch_metrics'].get('val_loss_wce') is not None:
|
| 498 |
+
summary_text += f"""
|
| 499 |
+
Val WCE: {analysis['best_epoch_metrics']['val_loss_wce']:.4f}
|
| 500 |
+
Val UFD: {analysis['best_epoch_metrics']['val_loss_ufd']:.4f}"""
|
| 501 |
+
|
| 502 |
+
summary_text += f"""
|
| 503 |
+
β value: {analysis['best_epoch_metrics']['beta_value']:.4f}
|
| 504 |
+
|
| 505 |
+
{'─' * 40}
|
| 506 |
+
Scoring Weights:
|
| 507 |
+
{abnormal_class}: 60%
|
| 508 |
+
Ventricles: 30%
|
| 509 |
+
Val Loss: 10%
|
| 510 |
+
"""
|
| 511 |
+
|
| 512 |
+
ax6.text(0.05, 0.95, summary_text, transform=ax6.transAxes,
|
| 513 |
+
fontsize=9, verticalalignment='top', fontfamily='monospace',
|
| 514 |
+
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))
|
| 515 |
+
|
| 516 |
+
plt.suptitle('a U-Net Training History - Comprehensive Analysis\n'
|
| 517 |
+
'(Adaptive Loss: WCE + UFD with β schedule)',
|
| 518 |
+
fontsize=16, fontweight='bold', y=0.998)
|
| 519 |
+
|
| 520 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 521 |
+
print(f"✓ Visualization saved to: {save_path}")
|
| 522 |
+
# plt.show()
|
| 523 |
+
|
| 524 |
+
return analysis
|
| 525 |
+
|
| 526 |
+
def print_detailed_analysis(analysis):
|
| 527 |
+
"""Print detailed analysis to console."""
|
| 528 |
+
if not analysis:
|
| 529 |
+
print("No analysis available.")
|
| 530 |
+
return
|
| 531 |
+
|
| 532 |
+
print("\n" + "="*60)
|
| 533 |
+
print("DETAILED TRAINING ANALYSIS - a U-NET")
|
| 534 |
+
print("="*60)
|
| 535 |
+
print(f"\n📊 Number of Classes: {analysis['num_classes']}")
|
| 536 |
+
print(f"\n🏆 RECOMMENDED EPOCH: {analysis['best_overall_epoch']}")
|
| 537 |
+
print(f" Composite Score: {analysis['composite_score']:.4f}")
|
| 538 |
+
print("\n" + "-"*60)
|
| 539 |
+
print("Individual Best Performances:")
|
| 540 |
+
print("-"*60)
|
| 541 |
+
print(f"\n🎯 Abnormal WMH Dice (TOP PRIORITY):")
|
| 542 |
+
print(f" Best Epoch: {analysis['best_abnormal_epoch']}")
|
| 543 |
+
print(f" Best Score: {analysis['best_abnormal_dice']:.4f}")
|
| 544 |
+
print(f"\n🫀 Ventricles Dice (SECONDARY):")
|
| 545 |
+
print(f" Best Epoch: {analysis['best_ventricles_epoch']}")
|
| 546 |
+
print(f" Best Score: {analysis['best_ventricles_dice']:.4f}")
|
| 547 |
+
print(f"\n📉 Validation Loss (TERTIARY):")
|
| 548 |
+
print(f" Best Epoch: {analysis['best_val_loss_epoch']}")
|
| 549 |
+
print(f" Lowest Loss: {analysis['best_val_loss']:.4f}")
|
| 550 |
+
print("\n" + "="*60)
|
| 551 |
+
print("\nNote: Best overall epoch is calculated using weighted scoring:")
|
| 552 |
+
print(" • Abnormal WMH Dice: 60%")
|
| 553 |
+
print(" • Ventricles Dice: 30%")
|
| 554 |
+
print(" • Validation Loss: 10%")
|
| 555 |
+
print("="*60 + "\n")
|
| 556 |
+
|
| 557 |
+
def main_viz(filepath='history_sample.json', save_outputs=True):
|
| 558 |
+
"""Main execution function."""
|
| 559 |
+
# Load history
|
| 560 |
+
print(f"Loading training history from: {filepath}")
|
| 561 |
+
history = load_history(filepath)
|
| 562 |
+
|
| 563 |
+
print(f"✓ Loaded {len(history['train_loss'])} epochs of training data")
|
| 564 |
+
|
| 565 |
+
# Get output directory
|
| 566 |
+
out_dir = os.path.dirname(filepath)
|
| 567 |
+
|
| 568 |
+
# Detect number of classes and get class names
|
| 569 |
+
num_classes = detect_num_classes(history)
|
| 570 |
+
class_names = get_class_names(num_classes)
|
| 571 |
+
|
| 572 |
+
# Find best epoch and create analysis
|
| 573 |
+
best_epoch, analysis = find_best_epoch(history, num_classes)
|
| 574 |
+
|
| 575 |
+
# Create visualization
|
| 576 |
+
plot_training_history(history, save_path=os.path.join(out_dir, 'a_unet_training_analysis.png'))
|
| 577 |
+
|
| 578 |
+
# Print detailed analysis
|
| 579 |
+
print_detailed_analysis(analysis)
|
| 580 |
+
|
| 581 |
+
if save_outputs:
|
| 582 |
+
print("\n" + "="*60)
|
| 583 |
+
print("SAVING ANALYSIS OUTPUTS")
|
| 584 |
+
print("="*60)
|
| 585 |
+
|
| 586 |
+
# 1. Save standalone analysis JSON
|
| 587 |
+
analysis_path = os.path.join(out_dir, 'best_epoch_analysis.json')
|
| 588 |
+
save_analysis_json(analysis, analysis_path)
|
| 589 |
+
|
| 590 |
+
# 2. Save enhanced history with analysis appended
|
| 591 |
+
enhanced_history_path = os.path.join(out_dir, 'history_with_analysis.json')
|
| 592 |
+
save_enhanced_history(history, analysis, enhanced_history_path)
|
| 593 |
+
|
| 594 |
+
# 3. Save training summary
|
| 595 |
+
summary = create_training_summary(history, analysis, class_names)
|
| 596 |
+
summary_path = os.path.join(out_dir, 'training_summary.json')
|
| 597 |
+
with open(summary_path, 'w') as f:
|
| 598 |
+
json.dump(summary, f, indent=2)
|
| 599 |
+
print(f"✓ Training summary saved to: {summary_path}")
|
| 600 |
+
|
| 601 |
+
print("\n" + "="*60)
|
| 602 |
+
print("ALL OUTPUTS SAVED SUCCESSFULLY")
|
| 603 |
+
print("="*60)
|
| 604 |
+
print("\nGenerated files:")
|
| 605 |
+
print(f" 1. unet_training_analysis.png - Visualization")
|
| 606 |
+
print(f" 2. best_epoch_analysis.json - Best epoch analysis")
|
| 607 |
+
print(f" 3. history_with_analysis.json - Enhanced history")
|
| 608 |
+
print(f" 4. training_summary.json - Comprehensive training summary")
|
| 609 |
+
print("="*60 + "\n")
|
| 610 |
+
|
| 611 |
+
return analysis, history
|
| 612 |
+
|
| 613 |
+
if __name__ == "__main__":
|
| 614 |
+
|
| 615 |
+
# experiment_dir = '/mnt/e/MBashiri/ours_articles/Paper#2/Development/results_unet_baseline_fold_0/models'
|
| 616 |
+
# scenario = 'standard_4class'
|
| 617 |
+
# fold_num = 'fold_0'
|
| 618 |
+
# filepath = os.path.join(experiment_dir, scenario, fold_num, 'history.json')
|
| 619 |
+
|
| 620 |
+
# main_viz(filepath=filepath, save_outputs=True)
|
| 621 |
+
|
| 622 |
+
for fold in range(5):
|
| 623 |
+
|
| 624 |
+
# Skip folds:
|
| 625 |
+
if fold in list(np.array([0, 2, 3, 4])):
|
| 626 |
+
continue
|
| 627 |
+
|
| 628 |
+
for variant in range(5):
|
| 629 |
+
|
| 630 |
+
# # Skip variants:
|
| 631 |
+
if variant not in list(np.array([1])):
|
| 632 |
+
continue
|
| 633 |
+
|
| 634 |
+
experiment_dir = f'/mnt/e/MBashiri/ours_articles/Paper#4/Development/results_fold_{fold}_var_{variant}_zscore2/models'
|
| 635 |
+
scenario = 'standard_3class'
|
| 636 |
+
fold_num = f'fold_{fold}'
|
| 637 |
+
filepath = os.path.join(experiment_dir, scenario, fold_num, 'history.json')
|
| 638 |
+
|
| 639 |
+
main_viz(filepath=filepath)
|
| 640 |
+
|
models/for_WMH_Vent/model_training_scripts/p4_variant_all_net.py
ADDED
|
@@ -0,0 +1,1051 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 - All U-Net models with Adaptive Loss (WCE + UFL)
|
| 3 |
+
|
| 4 |
+
WMH and Ventricles Segmentation with U-Net Models - Journal Paper Implementation
|
| 5 |
+
Three-class segmentation: Background vs Ventricles vs Abnormal WMH
|
| 6 |
+
Professional results saving and visualization for publication
|
| 7 |
+
|
| 8 |
+
This relates to our article:
|
| 9 |
+
"Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
|
| 10 |
+
A Large-Scale Multiple Sclerosis Study"
|
| 11 |
+
|
| 12 |
+
Features:
|
| 13 |
+
- Various U-Net architecture
|
| 14 |
+
- Weighted Categorical Cross-Entropy loss
|
| 15 |
+
- Unified Focal loss
|
| 16 |
+
- One-hot encoded targets
|
| 17 |
+
- Class weight computation per fold
|
| 18 |
+
|
| 19 |
+
Authors:
|
| 20 |
+
"Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
|
| 21 |
+
|
| 22 |
+
Developer:
|
| 23 |
+
"Mahdi Bashiri Bawil"
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import tensorflow as tf
|
| 27 |
+
import os
|
| 28 |
+
import time
|
| 29 |
+
import numpy as np
|
| 30 |
+
import matplotlib.pyplot as plt
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
import json
|
| 34 |
+
|
| 35 |
+
# Import data loader
|
| 36 |
+
from p4_data_loader import DataConfig, P2DataLoader
|
| 37 |
+
|
| 38 |
+
# Import utilities from baseline
|
| 39 |
+
from utility_functions import (
|
| 40 |
+
clear_gpu_memory,
|
| 41 |
+
get_gpu_memory_info,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Import class weights utility
|
| 45 |
+
from p4_compute_class_weights import compute_and_save_class_weights, load_class_weights
|
| 46 |
+
|
| 47 |
+
print("TensorFlow Version:", tf.__version__)
|
| 48 |
+
|
| 49 |
+
###################### GPU Configuration ######################
|
| 50 |
+
|
| 51 |
+
# Configure GPU memory growth
|
| 52 |
+
physical_devices = tf.config.list_physical_devices('GPU')
|
| 53 |
+
if physical_devices:
|
| 54 |
+
try:
|
| 55 |
+
for device in physical_devices:
|
| 56 |
+
tf.config.experimental.set_memory_growth(device, True)
|
| 57 |
+
print("✅ GPU memory growth enabled")
|
| 58 |
+
print(f" Available GPUs: {len(physical_devices)}")
|
| 59 |
+
except RuntimeError as e:
|
| 60 |
+
print(f"GPU configuration error: {e}")
|
| 61 |
+
else:
|
| 62 |
+
print("⚠️ No GPU detected - training will be slow")
|
| 63 |
+
|
| 64 |
+
"""
|
| 65 |
+
GPU Memory Management for Sequential Experiments
|
| 66 |
+
To properly release memory between experiments
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
###################### Target Preparation ######################
|
| 70 |
+
|
| 71 |
+
def prepare_inputs(paired_input, target_mask, num_classes):
|
| 72 |
+
"""
|
| 73 |
+
Prepare inputs for training
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
paired_input: (bs, 256, 512, 1) with FLAIR + mask
|
| 77 |
+
target_mask: (bs, 256, 256) with class labels [0, num_classes-1]
|
| 78 |
+
num_classes: number of classes
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
flair_normalized: FLAIR normalized to [-1, 1]
|
| 82 |
+
target_onehot: One-hot encoded mask (bs, 256, 256, num_classes)
|
| 83 |
+
"""
|
| 84 |
+
# Extract FLAIR, previously normalized to [-1, 1]
|
| 85 |
+
flair_normalized = paired_input[:, :, :256, :]
|
| 86 |
+
|
| 87 |
+
# One-hot encode target
|
| 88 |
+
target_onehot = tf.one_hot(target_mask, depth=num_classes, dtype=tf.float32)
|
| 89 |
+
|
| 90 |
+
return flair_normalized, target_onehot
|
| 91 |
+
|
| 92 |
+
###################### Metrics Calculation ######################
|
| 93 |
+
|
| 94 |
+
def compute_classwise_metrics(all_val_true, all_val_pred, num_classes, exclude_class=None):
|
| 95 |
+
"""
|
| 96 |
+
Compute class-wise Dice, Precision, and Recall for validation predictions.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
all_val_true: List of one-hot encoded ground truth tensors
|
| 100 |
+
all_val_pred: List of softmax output tensors from model
|
| 101 |
+
num_classes: Number of classes (3 or 4)
|
| 102 |
+
exclude_class: Class to exclude from metric calculation (e.g., 2 for background)
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
Dictionary containing class-wise and mean metrics
|
| 106 |
+
"""
|
| 107 |
+
# Concatenate all batches
|
| 108 |
+
y_true_concat = tf.concat(all_val_true, axis=0) # Shape: (N, H, W, num_classes)
|
| 109 |
+
y_pred_concat = tf.concat(all_val_pred, axis=0) # Shape: (N, H, W, num_classes)
|
| 110 |
+
|
| 111 |
+
# Flatten spatial dimensions: (N*H*W, num_classes)
|
| 112 |
+
y_true_flat = tf.reshape(y_true_concat, [-1, num_classes])
|
| 113 |
+
y_pred_flat = tf.reshape(y_pred_concat, [-1, num_classes])
|
| 114 |
+
|
| 115 |
+
# Convert predictions to one-hot (argmax)
|
| 116 |
+
y_pred_classes = tf.argmax(y_pred_flat, axis=-1)
|
| 117 |
+
y_pred_onehot = tf.one_hot(y_pred_classes, depth=num_classes)
|
| 118 |
+
|
| 119 |
+
# Convert to numpy for easier computation
|
| 120 |
+
y_true_np = y_true_flat.numpy()
|
| 121 |
+
y_pred_np = y_pred_onehot.numpy()
|
| 122 |
+
|
| 123 |
+
metrics = {
|
| 124 |
+
'dice': {},
|
| 125 |
+
'precision': {},
|
| 126 |
+
'recall': {}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
classes_to_evaluate = [c for c in range(num_classes) if c != exclude_class]
|
| 130 |
+
|
| 131 |
+
for class_idx in classes_to_evaluate:
|
| 132 |
+
# Extract binary masks for this class
|
| 133 |
+
true_class = y_true_np[:, class_idx]
|
| 134 |
+
pred_class = y_pred_np[:, class_idx]
|
| 135 |
+
|
| 136 |
+
# True Positives, False Positives, False Negatives
|
| 137 |
+
TP = np.sum((true_class == 1) & (pred_class == 1))
|
| 138 |
+
FP = np.sum((true_class == 0) & (pred_class == 1))
|
| 139 |
+
FN = np.sum((true_class == 1) & (pred_class == 0))
|
| 140 |
+
|
| 141 |
+
# Dice Score: 2*TP / (2*TP + FP + FN)
|
| 142 |
+
dice = (2 * TP) / (2 * TP + FP + FN + 1e-7)
|
| 143 |
+
|
| 144 |
+
# Precision: TP / (TP + FP)
|
| 145 |
+
precision = TP / (TP + FP + 1e-7)
|
| 146 |
+
|
| 147 |
+
# Recall (Sensitivity): TP / (TP + FN)
|
| 148 |
+
recall = TP / (TP + FN + 1e-7)
|
| 149 |
+
|
| 150 |
+
metrics['dice'][f'class_{class_idx}'] = float(dice)
|
| 151 |
+
metrics['precision'][f'class_{class_idx}'] = float(precision)
|
| 152 |
+
metrics['recall'][f'class_{class_idx}'] = float(recall)
|
| 153 |
+
|
| 154 |
+
# Compute mean metrics (excluding the excluded class)
|
| 155 |
+
metrics['dice']['mean'] = np.mean([v for v in metrics['dice'].values()])
|
| 156 |
+
metrics['precision']['mean'] = np.mean([v for v in metrics['precision'].values()])
|
| 157 |
+
metrics['recall']['mean'] = np.mean([v for v in metrics['recall'].values()])
|
| 158 |
+
|
| 159 |
+
return metrics
|
| 160 |
+
|
| 161 |
+
###################### Experiment Configuration ######################
|
| 162 |
+
|
| 163 |
+
class ExperimentConfig:
|
| 164 |
+
"""Configuration for a Specific U-Net experiment"""
|
| 165 |
+
|
| 166 |
+
def __init__(self,
|
| 167 |
+
variant: int = 1,
|
| 168 |
+
preprocessing: str = 'standard',
|
| 169 |
+
class_scenario: str = '3class',
|
| 170 |
+
fold_id: int = 0,
|
| 171 |
+
architecture_name: str = 'unet'
|
| 172 |
+
):
|
| 173 |
+
|
| 174 |
+
# Experiment identification
|
| 175 |
+
self.variant = variant
|
| 176 |
+
self.preprocessing = preprocessing # 'standard' or 'zoomed'
|
| 177 |
+
self.class_scenario = class_scenario # '3class' or '4class'
|
| 178 |
+
self.fold_id = fold_id
|
| 179 |
+
self.architecture_name = architecture_name
|
| 180 |
+
|
| 181 |
+
# Experiment name
|
| 182 |
+
self.exp_name = f"exp_{architecture_name}_{preprocessing}_{class_scenario}_fold{fold_id}"
|
| 183 |
+
|
| 184 |
+
# Number of classes
|
| 185 |
+
self.num_classes = 3 if class_scenario == '3class' else 4
|
| 186 |
+
|
| 187 |
+
# Training hyperparameters
|
| 188 |
+
self.batch_size = 4
|
| 189 |
+
self.img_width = 256
|
| 190 |
+
self.img_height = 256
|
| 191 |
+
self.epochs = 60
|
| 192 |
+
|
| 193 |
+
# Optimizer parameters
|
| 194 |
+
self.learning_rate = 2e-4
|
| 195 |
+
self.beta_1 = 0.9
|
| 196 |
+
|
| 197 |
+
# Adaptive loss parameters
|
| 198 |
+
self.focal_gamma = 0.5 # Focal loss focusing parameter
|
| 199 |
+
self.beta_threshold = 0.25 # Transition at epoch 15/60
|
| 200 |
+
self.beta_smoothness = 0.02 # Transition width
|
| 201 |
+
self.use_focal_alpha = True # Use class weights in focal loss
|
| 202 |
+
|
| 203 |
+
# ReduceLROnPlateau parameters
|
| 204 |
+
self.lr_patience = 5 # Wait 5 epochs before reducing
|
| 205 |
+
self.lr_reduction_factor = 0.5 # Reduce LR by half
|
| 206 |
+
self.lr_min = 1e-7 # Don't go below this
|
| 207 |
+
self.lr_monitor = 'val_loss' # Or 'val_dice_mean'
|
| 208 |
+
|
| 209 |
+
# Paths
|
| 210 |
+
self.results_dir = Path(f"results_fold_{fold_id}_var_{variant}_zscore3")
|
| 211 |
+
self.models_dir = self.results_dir / "models" / f"{preprocessing}_{class_scenario}"
|
| 212 |
+
self.figures_dir = self.results_dir / "figures" / f"{preprocessing}_{class_scenario}" / f"fold_{fold_id}"
|
| 213 |
+
self.logs_dir = self.results_dir / "logs" / f"{preprocessing}_{class_scenario}" / f"fold_{fold_id}"
|
| 214 |
+
|
| 215 |
+
# Create directories
|
| 216 |
+
self.models_dir.mkdir(parents=True, exist_ok=True)
|
| 217 |
+
self.figures_dir.mkdir(parents=True, exist_ok=True)
|
| 218 |
+
self.logs_dir.mkdir(parents=True, exist_ok=True)
|
| 219 |
+
|
| 220 |
+
# Checkpoint configuration
|
| 221 |
+
self.checkpoint_dir = self.models_dir / f"fold_{fold_id}"
|
| 222 |
+
self.checkpoint_dir.mkdir(exist_ok=True)
|
| 223 |
+
|
| 224 |
+
# Class weights directory
|
| 225 |
+
self.weights_dir = Path("class_weights")
|
| 226 |
+
self.weights_dir.mkdir(exist_ok=True)
|
| 227 |
+
|
| 228 |
+
# Save configuration
|
| 229 |
+
self.save_config()
|
| 230 |
+
|
| 231 |
+
def save_config(self):
|
| 232 |
+
"""Save experiment configuration to JSON"""
|
| 233 |
+
config_dict = {
|
| 234 |
+
'variant': self.variant,
|
| 235 |
+
'variant_name': f'{self.architecture_name}',
|
| 236 |
+
'preprocessing': self.preprocessing,
|
| 237 |
+
'class_scenario': self.class_scenario,
|
| 238 |
+
'fold_id': self.fold_id,
|
| 239 |
+
'num_classes': self.num_classes,
|
| 240 |
+
'batch_size': self.batch_size,
|
| 241 |
+
'epochs': self.epochs,
|
| 242 |
+
'focal_gamma': self.focal_gamma,
|
| 243 |
+
'beta_threshold': self.beta_threshold,
|
| 244 |
+
'beta_smoothness': self.beta_smoothness,
|
| 245 |
+
'learning_rate': self.learning_rate,
|
| 246 |
+
'beta_1': self.beta_1,
|
| 247 |
+
'loss': 'Phase-transitioning segmentation loss (WCE → UFD)'
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
config_file = self.checkpoint_dir / "config.json"
|
| 251 |
+
with open(config_file, 'w') as f:
|
| 252 |
+
json.dump(config_dict, f, indent=2)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
###################### Beta Scheduling ######################
|
| 256 |
+
|
| 257 |
+
def smooth_step(x, threshold=0.5, smoothness=0.1):
|
| 258 |
+
"""
|
| 259 |
+
Smooth step function for phase transition
|
| 260 |
+
|
| 261 |
+
Creates smooth transition around threshold value using sigmoid.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
x: Current progress (typically epoch / total_epochs)
|
| 265 |
+
threshold: Center point of transition (e.g., 0.5 for epoch 25/50)
|
| 266 |
+
smoothness: Width of transition (smaller = sharper, larger = smoother)
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Value in [0, 1] representing transition progress
|
| 270 |
+
- x << threshold: returns ≈ 0
|
| 271 |
+
- x ≈ threshold: returns ≈ 0.5
|
| 272 |
+
- x >> threshold: returns ≈ 1
|
| 273 |
+
|
| 274 |
+
Example:
|
| 275 |
+
epoch_progress = 0.3 # Epoch 15/50
|
| 276 |
+
beta = smooth_step(0.3, threshold=0.5, smoothness=0.1)
|
| 277 |
+
# beta ≈ 0.05 (mostly phase 1)
|
| 278 |
+
|
| 279 |
+
epoch_progress = 0.5 # Epoch 25/50
|
| 280 |
+
beta = smooth_step(0.5, threshold=0.5, smoothness=0.1)
|
| 281 |
+
# beta ≈ 0.5 (equal mix)
|
| 282 |
+
|
| 283 |
+
epoch_progress = 0.7 # Epoch 35/50
|
| 284 |
+
beta = smooth_step(0.7, threshold=0.5, smoothness=0.1)
|
| 285 |
+
# beta ≈ 0.95 (mostly phase 2)
|
| 286 |
+
"""
|
| 287 |
+
# Sigmoid centered at threshold
|
| 288 |
+
# (x - threshold) / smoothness controls steepness
|
| 289 |
+
return tf.sigmoid((x - threshold) / smoothness)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def compute_beta_schedule(current_epoch, total_epochs,
|
| 293 |
+
threshold=0.5, smoothness=0.1):
|
| 294 |
+
"""
|
| 295 |
+
Compute beta value for current epoch
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
current_epoch: Current epoch number (0-indexed)
|
| 299 |
+
total_epochs: Total number of epochs
|
| 300 |
+
threshold: Transition center (0.5 = midpoint)
|
| 301 |
+
smoothness: Transition width
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
Beta value in [0, 1]
|
| 305 |
+
"""
|
| 306 |
+
epoch_progress = tf.cast(current_epoch, tf.float32) / tf.cast(total_epochs, tf.float32)
|
| 307 |
+
beta = smooth_step(epoch_progress, threshold, smoothness)
|
| 308 |
+
return beta
|
| 309 |
+
|
| 310 |
+
###################### Loss Functions ######################
|
| 311 |
+
|
| 312 |
+
def unified_focal_loss(y_true, y_pred, gamma=2.0, alpha=None, exclude_class=None):
|
| 313 |
+
"""
|
| 314 |
+
Unified Focal Loss
|
| 315 |
+
|
| 316 |
+
Focal loss down-weights easy examples and focuses on hard examples.
|
| 317 |
+
Particularly effective for class imbalance and boundary regions.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
y_true: Ground truth labels (bs, H, W, num_classes) one-hot encoded
|
| 321 |
+
y_pred: Predicted probabilities (bs, H, W, num_classes) from softmax
|
| 322 |
+
gamma: Focusing parameter (default 2.0)
|
| 323 |
+
- gamma=0: equivalent to cross-entropy
|
| 324 |
+
- gamma>0: down-weights easy examples
|
| 325 |
+
- Higher gamma = more focus on hard examples
|
| 326 |
+
alpha: Per-class balancing weights (num_classes,) - optional, trainable
|
| 327 |
+
- If None, no additional balancing
|
| 328 |
+
- If provided, applies per-class weighting like weighted CE
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
Scalar loss value
|
| 332 |
+
|
| 333 |
+
Formula:
|
| 334 |
+
FL = -α * (1 - p_t)^γ * log(p_t)
|
| 335 |
+
where:
|
| 336 |
+
- p_t is probability of correct class
|
| 337 |
+
- (1 - p_t)^γ is modulating factor (focal term)
|
| 338 |
+
- α is class balancing weight
|
| 339 |
+
"""
|
| 340 |
+
# Clip predictions to avoid log(0)
|
| 341 |
+
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
|
| 342 |
+
|
| 343 |
+
# Probability of correct class at each pixel
|
| 344 |
+
# y_true is one-hot, so this extracts p for the true class
|
| 345 |
+
p_t = tf.reduce_sum(y_true * y_pred, axis=-1)
|
| 346 |
+
# Shape: (bs, H, W)
|
| 347 |
+
|
| 348 |
+
# Focal term: (1 - p_t)^gamma
|
| 349 |
+
# This is small for easy examples (p_t ≈ 1) and large for hard examples (p_t ≈ 0)
|
| 350 |
+
focal_term = tf.pow(1.0 - p_t, gamma)
|
| 351 |
+
# Shape: (bs, H, W)
|
| 352 |
+
|
| 353 |
+
# Cross-entropy term: -log(p_t)
|
| 354 |
+
ce_term = -tf.math.log(p_t)
|
| 355 |
+
# Shape: (bs, H, W)
|
| 356 |
+
|
| 357 |
+
# Focal loss: focal_term * ce_term
|
| 358 |
+
focal_loss = focal_term * ce_term
|
| 359 |
+
# Shape: (bs, H, W)
|
| 360 |
+
|
| 361 |
+
# Optional: Apply alpha balancing (per-class weights)
|
| 362 |
+
if alpha is not None:
|
| 363 |
+
# Get weight for true class at each pixel
|
| 364 |
+
weights_tensor = tf.cast(alpha, dtype=tf.float32)
|
| 365 |
+
weights_tensor = tf.reshape(weights_tensor, [1, 1, 1, -1])
|
| 366 |
+
alpha_map = tf.reduce_sum(y_true * weights_tensor, axis=-1)
|
| 367 |
+
# Shape: (bs, H, W)
|
| 368 |
+
|
| 369 |
+
# Weighted focal
|
| 370 |
+
# Exclude specific class if specified
|
| 371 |
+
if exclude_class is not None:
|
| 372 |
+
class_mask = tf.argmax(y_true, axis=-1) # (bs, 256, 256)
|
| 373 |
+
valid_mask = tf.cast(class_mask != exclude_class, tf.float32)
|
| 374 |
+
|
| 375 |
+
if alpha is not None:
|
| 376 |
+
focal_loss = alpha_map * focal_loss * valid_mask
|
| 377 |
+
else:
|
| 378 |
+
focal_loss = focal_loss * valid_mask
|
| 379 |
+
|
| 380 |
+
return tf.reduce_sum(focal_loss) / (tf.reduce_sum(valid_mask) + 1e-7)
|
| 381 |
+
else:
|
| 382 |
+
|
| 383 |
+
if alpha is not None:
|
| 384 |
+
focal_loss = alpha_map * focal_loss
|
| 385 |
+
|
| 386 |
+
return tf.reduce_mean(focal_loss)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def unified_focal_dice_loss(y_true, y_pred, gamma=0.5, delta=0.6, alpha=None, exclude_class=None):
|
| 390 |
+
"""
|
| 391 |
+
Unified Focal Loss - Dice-based
|
| 392 |
+
|
| 393 |
+
Combines Dice coefficient with precision-recall focal weighting.
|
| 394 |
+
Best for imbalanced multi-class segmentation with small structures.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
y_true: Ground truth one-hot (bs, H, W, num_classes)
|
| 398 |
+
y_pred: Predicted probabilities (bs, H, W, num_classes)
|
| 399 |
+
gamma: Focusing parameter for Dice component (default 0.5)
|
| 400 |
+
- gamma=0: equivalent to Dice loss
|
| 401 |
+
- gamma>0: focuses on hard examples
|
| 402 |
+
delta: Weight for precision-recall component (0-1, default 0.6)
|
| 403 |
+
- Controls emphasis on boundary regions
|
| 404 |
+
alpha: Per-class weights (num_classes,) - optional
|
| 405 |
+
exclude_class: Class index to exclude from loss
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
Scalar loss value
|
| 409 |
+
|
| 410 |
+
Formula:
|
| 411 |
+
UFL = (1 - Dice)^gamma * (1 - precision * recall)^delta
|
| 412 |
+
Focuses on hard examples and boundary regions
|
| 413 |
+
"""
|
| 414 |
+
smooth = 1e-6
|
| 415 |
+
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
|
| 416 |
+
num_classes = tf.shape(y_pred)[-1]
|
| 417 |
+
|
| 418 |
+
unified_losses = []
|
| 419 |
+
|
| 420 |
+
for class_idx in range(num_classes if isinstance(num_classes, int) else y_pred.shape[-1]):
|
| 421 |
+
# Skip excluded class
|
| 422 |
+
if exclude_class is not None and class_idx == exclude_class:
|
| 423 |
+
continue
|
| 424 |
+
|
| 425 |
+
y_true_class = y_true[..., class_idx]
|
| 426 |
+
y_pred_class = y_pred[..., class_idx]
|
| 427 |
+
|
| 428 |
+
# Flatten for calculations
|
| 429 |
+
y_true_f = tf.reshape(y_true_class, [-1])
|
| 430 |
+
y_pred_f = tf.reshape(y_pred_class, [-1])
|
| 431 |
+
|
| 432 |
+
# True positives, false positives, false negatives
|
| 433 |
+
tp = tf.reduce_sum(y_true_f * y_pred_f)
|
| 434 |
+
fp = tf.reduce_sum((1.0 - y_true_f) * y_pred_f)
|
| 435 |
+
fn = tf.reduce_sum(y_true_f * (1.0 - y_pred_f))
|
| 436 |
+
|
| 437 |
+
# Precision and recall
|
| 438 |
+
precision = (tp + smooth) / (tp + fp + smooth)
|
| 439 |
+
recall = (tp + smooth) / (tp + fn + smooth)
|
| 440 |
+
|
| 441 |
+
# Dice coefficient
|
| 442 |
+
dice = (2.0 * tp + smooth) / (2.0 * tp + fp + fn + smooth)
|
| 443 |
+
|
| 444 |
+
# Unified focal loss: focuses on hard examples and boundary regions
|
| 445 |
+
# (1 - dice)^gamma: focuses on classes with low Dice (hard examples)
|
| 446 |
+
# (1 - precision * recall)^delta: focuses on boundary regions
|
| 447 |
+
unified_loss_class = tf.pow(1.0 - dice, gamma) * tf.pow(1.0 - precision * recall, delta)
|
| 448 |
+
|
| 449 |
+
# Apply class weights
|
| 450 |
+
if alpha is not None:
|
| 451 |
+
unified_loss_class = unified_loss_class * tf.cast(alpha[class_idx], tf.float32)
|
| 452 |
+
|
| 453 |
+
unified_losses.append(unified_loss_class)
|
| 454 |
+
|
| 455 |
+
# Stack and mean across classes (excluding the skipped class)
|
| 456 |
+
total_loss = tf.reduce_mean(tf.stack(unified_losses))
|
| 457 |
+
|
| 458 |
+
return total_loss
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def weighted_categorical_crossentropy(y_true, y_pred, class_weights, exclude_class=None):
|
| 462 |
+
"""
|
| 463 |
+
Weighted categorical cross-entropy loss
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
y_true: (bs, 256, 256, num_classes) one-hot encoded
|
| 467 |
+
y_pred: (bs, 256, 256, num_classes) softmax probabilities
|
| 468 |
+
class_weights: (num_classes,) weight per class
|
| 469 |
+
exclude_class: Optional int, class index to exclude from loss (e.g., 2 for CSF)
|
| 470 |
+
|
| 471 |
+
Returns:
|
| 472 |
+
Scalar loss value
|
| 473 |
+
"""
|
| 474 |
+
# Clip predictions to prevent log(0)
|
| 475 |
+
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
|
| 476 |
+
|
| 477 |
+
# Cross-entropy per pixel: -sum(y_true * log(y_pred))
|
| 478 |
+
ce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1) # (bs, 256, 256)
|
| 479 |
+
|
| 480 |
+
# Apply class weights
|
| 481 |
+
# class_weights shape: (num_classes,) -> (1, 1, 1, num_classes) for broadcasting
|
| 482 |
+
weights_tensor = tf.cast(class_weights, dtype=tf.float32)
|
| 483 |
+
weights_tensor = tf.reshape(weights_tensor, [1, 1, 1, -1])
|
| 484 |
+
|
| 485 |
+
# Weight map: (bs, 256, 256)
|
| 486 |
+
pixel_weights = tf.reduce_sum(y_true * weights_tensor, axis=-1)
|
| 487 |
+
|
| 488 |
+
# Weighted cross-entropy
|
| 489 |
+
# Exclude specific class if specified
|
| 490 |
+
if exclude_class is not None:
|
| 491 |
+
class_mask = tf.argmax(y_true, axis=-1) # (bs, 256, 256)
|
| 492 |
+
valid_mask = tf.cast(class_mask != exclude_class, tf.float32)
|
| 493 |
+
weighted_ce = ce * pixel_weights * valid_mask
|
| 494 |
+
return tf.reduce_sum(weighted_ce) / (tf.reduce_sum(valid_mask) + 1e-7)
|
| 495 |
+
else:
|
| 496 |
+
weighted_ce = ce * pixel_weights
|
| 497 |
+
return tf.reduce_mean(weighted_ce)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def adaptive_segmentation_loss(y_true, y_pred, class_weights, beta,
|
| 501 |
+
focal_gamma=0.5, use_focal_alpha=True,
|
| 502 |
+
exclude_class=None):
|
| 503 |
+
"""
|
| 504 |
+
Adaptive segmentation loss with hard phase transition
|
| 505 |
+
|
| 506 |
+
Combines weighted cross-entropy (phase 1) and focal loss (phase 2)
|
| 507 |
+
based on epoch progress (beta).
|
| 508 |
+
|
| 509 |
+
Args:
|
| 510 |
+
y_true: Ground truth (bs, H, W, num_classes) one-hot
|
| 511 |
+
y_pred: Predictions (bs, H, W, num_classes) softmax probabilities
|
| 512 |
+
class_weights: Trainable class weights (num_classes,)
|
| 513 |
+
beta: Transition parameter [0, 1]
|
| 514 |
+
- beta=0: pure weighted CE (early training)
|
| 515 |
+
- beta=1: pure focal loss (late training)
|
| 516 |
+
focal_gamma: Focusing parameter for focal loss (default 0.5)
|
| 517 |
+
use_focal_alpha: Whether to use class_weights as focal alpha
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
seg_loss: Final loss
|
| 521 |
+
wcce_loss: Weighted CE component (for monitoring)
|
| 522 |
+
focal_loss: Focal loss component (for monitoring)
|
| 523 |
+
|
| 524 |
+
Phase Behavior:
|
| 525 |
+
Epochs 1-10: beta ≈ 0 → Weighted CE dominates
|
| 526 |
+
- Learns basic class separation
|
| 527 |
+
- Benefits from explicit class weighting
|
| 528 |
+
|
| 529 |
+
Epochs 10-20: beta transitions 0 → 1
|
| 530 |
+
- Smooth change in loss landscape
|
| 531 |
+
- Gradual shift in training dynamics
|
| 532 |
+
|
| 533 |
+
Epochs 20-60: beta ≈ 1 → Focal loss dominates
|
| 534 |
+
- Focuses on hard examples
|
| 535 |
+
- Refines boundaries and difficult regions
|
| 536 |
+
"""
|
| 537 |
+
# Compute Phase 1 loss: Weighted Cross-Entropy
|
| 538 |
+
wcce_loss = 10 * weighted_categorical_crossentropy(y_true, y_pred, class_weights, exclude_class=exclude_class)
|
| 539 |
+
|
| 540 |
+
# Compute Phase 2 loss: Focal Loss
|
| 541 |
+
focal_alpha = class_weights if use_focal_alpha else None
|
| 542 |
+
focal_loss = unified_focal_dice_loss(y_true, y_pred,
|
| 543 |
+
gamma=focal_gamma,
|
| 544 |
+
alpha=focal_alpha,
|
| 545 |
+
exclude_class=exclude_class)
|
| 546 |
+
|
| 547 |
+
# Adaptive combination based on beta
|
| 548 |
+
# beta=0: (1-0)*wce + 0*focal = wce (phase 1)
|
| 549 |
+
# beta=1: (1-1)*wce + 1*focal = focal (phase 2)
|
| 550 |
+
# beta=0.5: 0.5*wce + 0.5*focal = equal mix (transition)
|
| 551 |
+
seg_loss = (1.0 - beta) * wcce_loss + beta * focal_loss
|
| 552 |
+
|
| 553 |
+
return seg_loss, wcce_loss, focal_loss
|
| 554 |
+
|
| 555 |
+
###################### Training Functions ######################
|
| 556 |
+
|
| 557 |
+
@tf.function
|
| 558 |
+
def train_step(input_image, target_onehot, model, optimizer,
|
| 559 |
+
class_weights, beta, focal_gamma,
|
| 560 |
+
use_focal_alpha=True, exclude_class=None):
|
| 561 |
+
"""
|
| 562 |
+
Single training step for U-Net
|
| 563 |
+
|
| 564 |
+
Args:
|
| 565 |
+
input_image: Input FLAIR (bs, 256, 256, 1) in [-1, 1]
|
| 566 |
+
target_onehot: Target mask (bs, 256, 256, num_classes) one-hot
|
| 567 |
+
model: a specific U-Net model
|
| 568 |
+
optimizer: Optimizer
|
| 569 |
+
class_weights: (num_classes,) weight per class
|
| 570 |
+
beta: Current beta for phase transition
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
Returns:
|
| 574 |
+
loss: Training loss value
|
| 575 |
+
"""
|
| 576 |
+
with tf.GradientTape() as tape:
|
| 577 |
+
# Forward pass
|
| 578 |
+
predictions = model(input_image, training=True)
|
| 579 |
+
|
| 580 |
+
# Compute loss
|
| 581 |
+
seg_loss, wcce_loss, focal_loss = adaptive_segmentation_loss(target_onehot, predictions, class_weights,
|
| 582 |
+
beta, focal_gamma, use_focal_alpha, exclude_class)
|
| 583 |
+
|
| 584 |
+
# Calculate gradients
|
| 585 |
+
gradients = tape.gradient(seg_loss, model.trainable_variables)
|
| 586 |
+
|
| 587 |
+
# Apply gradients
|
| 588 |
+
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
|
| 589 |
+
|
| 590 |
+
return seg_loss, wcce_loss, focal_loss
|
| 591 |
+
|
| 592 |
+
def generate_and_save_images(model, test_input, test_target,
|
| 593 |
+
epoch, save_path, num_classes):
|
| 594 |
+
"""
|
| 595 |
+
Generate predictions and save visualization
|
| 596 |
+
|
| 597 |
+
Args:
|
| 598 |
+
model: a specific U-Net model
|
| 599 |
+
test_input: Test input image (bs, 256, 512, 1)
|
| 600 |
+
test_target: Test target mask (bs, 256, 256)
|
| 601 |
+
epoch: Current epoch number
|
| 602 |
+
save_path: Path to save figure
|
| 603 |
+
num_classes: Number of classes
|
| 604 |
+
"""
|
| 605 |
+
for ik in range(test_input.numpy().shape[0]):
|
| 606 |
+
# Extract FLAIR
|
| 607 |
+
flair_normalized = test_input[ik, :, :256, :]
|
| 608 |
+
flair_normalized = tf.expand_dims(flair_normalized, axis=0)
|
| 609 |
+
|
| 610 |
+
# Generate prediction
|
| 611 |
+
prediction_softmax = model(flair_normalized, training=False)
|
| 612 |
+
|
| 613 |
+
# Convert to class labels
|
| 614 |
+
pred_classes = tf.argmax(prediction_softmax, axis=-1).numpy()
|
| 615 |
+
target_mask = test_target[ik].numpy()
|
| 616 |
+
|
| 617 |
+
# Create figure
|
| 618 |
+
plt.figure(figsize=(20, 5))
|
| 619 |
+
|
| 620 |
+
# Input FLAIR
|
| 621 |
+
plt.subplot(1, 5, 1)
|
| 622 |
+
plt.title('Input FLAIR')
|
| 623 |
+
plt.imshow(flair_normalized[0, :, :, 0], cmap='gray')
|
| 624 |
+
plt.axis('off')
|
| 625 |
+
|
| 626 |
+
# Ground truth
|
| 627 |
+
plt.subplot(1, 5, 2)
|
| 628 |
+
plt.title('Ground Truth')
|
| 629 |
+
plt.imshow(target_mask, cmap='jet', vmin=0, vmax=num_classes-1)
|
| 630 |
+
plt.colorbar()
|
| 631 |
+
plt.axis('off')
|
| 632 |
+
|
| 633 |
+
# Prediction
|
| 634 |
+
plt.subplot(1, 5, 3)
|
| 635 |
+
plt.title('Predicted Classes')
|
| 636 |
+
plt.imshow(pred_classes[0], cmap='jet', vmin=0, vmax=num_classes-1)
|
| 637 |
+
plt.colorbar()
|
| 638 |
+
plt.axis('off')
|
| 639 |
+
|
| 640 |
+
# Class probabilities for most confident prediction
|
| 641 |
+
plt.subplot(1, 5, 4)
|
| 642 |
+
plt.title('Max Probability')
|
| 643 |
+
max_prob = tf.reduce_max(prediction_softmax[0], axis=-1).numpy()
|
| 644 |
+
plt.imshow(max_prob, cmap='viridis', vmin=0, vmax=1)
|
| 645 |
+
plt.colorbar()
|
| 646 |
+
plt.axis('off')
|
| 647 |
+
|
| 648 |
+
# Difference map
|
| 649 |
+
plt.subplot(1, 5, 5)
|
| 650 |
+
plt.title('Error Map (Red=Wrong)')
|
| 651 |
+
error_map = (pred_classes[0] != target_mask).astype(float)
|
| 652 |
+
plt.imshow(error_map, cmap='Reds', vmin=0, vmax=1)
|
| 653 |
+
plt.colorbar()
|
| 654 |
+
plt.axis('off')
|
| 655 |
+
|
| 656 |
+
plt.tight_layout()
|
| 657 |
+
plt.savefig(save_path / f'epoch_{epoch:03d}_{ik+1}.png', dpi=300, bbox_inches='tight')
|
| 658 |
+
plt.close()
|
| 659 |
+
|
| 660 |
+
###################### Main Training Function ######################
|
| 661 |
+
|
| 662 |
+
def train_net(config: ExperimentConfig):
|
| 663 |
+
"""
|
| 664 |
+
Main training function for a Specific U-Net
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
config: ExperimentConfig object
|
| 668 |
+
"""
|
| 669 |
+
print("\n" + "="*70)
|
| 670 |
+
print(f"TRAINING {config.architecture_name}: {config.exp_name}")
|
| 671 |
+
print("="*70)
|
| 672 |
+
print(f"Variant: {config.variant}")
|
| 673 |
+
print(f"Preprocessing: {config.preprocessing}")
|
| 674 |
+
print(f"Class scenario: {config.class_scenario} ({config.num_classes} classes)")
|
| 675 |
+
print(f"Fold: {config.fold_id}")
|
| 676 |
+
print(f"Epochs: {config.epochs}")
|
| 677 |
+
print(f"Batch size: {config.batch_size}")
|
| 678 |
+
print(f"Loss: Weighted Categorical Cross-Entropy → Unified Focal")
|
| 679 |
+
print("="*70 + "\n")
|
| 680 |
+
|
| 681 |
+
# Check initial GPU memory
|
| 682 |
+
get_gpu_memory_info()
|
| 683 |
+
|
| 684 |
+
# Initialize data loader
|
| 685 |
+
data_config = DataConfig()
|
| 686 |
+
data_loader = P2DataLoader(data_config)
|
| 687 |
+
|
| 688 |
+
# Load datasets
|
| 689 |
+
print("Loading training data...")
|
| 690 |
+
train_dataset = data_loader.create_dataset_for_fold(
|
| 691 |
+
fold_id=config.fold_id,
|
| 692 |
+
split='train',
|
| 693 |
+
preprocessing=config.preprocessing,
|
| 694 |
+
class_scenario=config.class_scenario,
|
| 695 |
+
batch_size=config.batch_size,
|
| 696 |
+
shuffle=True
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
print("Loading validation data...")
|
| 700 |
+
val_dataset = data_loader.create_dataset_for_fold(
|
| 701 |
+
fold_id=config.fold_id,
|
| 702 |
+
split='val',
|
| 703 |
+
preprocessing=config.preprocessing,
|
| 704 |
+
class_scenario=config.class_scenario,
|
| 705 |
+
batch_size=config.batch_size,
|
| 706 |
+
shuffle=False
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
# Get dataset sizes
|
| 710 |
+
# Note: from_generator pipelines always report cardinality as INFINITE (-1)
|
| 711 |
+
# even with .cache(), so we derive the batch count from the slice list instead.
|
| 712 |
+
# We iterate once here; this also warms the in-memory cache so epoch 1 is fast.
|
| 713 |
+
print("Warming dataset cache (first pass over data — subsequent epochs use RAM)...")
|
| 714 |
+
train_size = sum(1 for _ in train_dataset)
|
| 715 |
+
val_size = sum(1 for _ in val_dataset)
|
| 716 |
+
# ⚠️ Do NOT rebuild the datasets here — that would create new generators and
|
| 717 |
+
# throw away the cache we just populated.
|
| 718 |
+
|
| 719 |
+
print(f"Training samples (batches): {train_size}")
|
| 720 |
+
print(f"Validation samples (batches): {val_size}\n")
|
| 721 |
+
|
| 722 |
+
# Compute or load class weights
|
| 723 |
+
print("Computing class weights from training data...")
|
| 724 |
+
try:
|
| 725 |
+
class_weights = load_class_weights(
|
| 726 |
+
config.fold_id, config.class_scenario,
|
| 727 |
+
config.preprocessing, config.weights_dir
|
| 728 |
+
)
|
| 729 |
+
print("✅ Loaded pre-computed class weights")
|
| 730 |
+
except FileNotFoundError:
|
| 731 |
+
print("Computing class weights (this may take a few minutes)...")
|
| 732 |
+
results = compute_and_save_class_weights(
|
| 733 |
+
config.fold_id, config.class_scenario,
|
| 734 |
+
config.preprocessing, str(config.weights_dir)
|
| 735 |
+
)
|
| 736 |
+
class_weights = np.array(results['class_weights'], dtype=np.float32)
|
| 737 |
+
|
| 738 |
+
print(f"Class weights: {class_weights}")
|
| 739 |
+
|
| 740 |
+
# Build model
|
| 741 |
+
print(f"\n🏗️ Building {config.architecture_name} model...")
|
| 742 |
+
|
| 743 |
+
if config.architecture_name == 'unet':
|
| 744 |
+
from unet_model import build_unet_3class as build_specific_3class # must be updated with the actual used model for traininig
|
| 745 |
+
elif config.architecture_name == 'attnunet':
|
| 746 |
+
from attn_unet_model import build_attention_unet_3class as build_specific_3class
|
| 747 |
+
elif config.architecture_name == 'dlv3unet':
|
| 748 |
+
from dlv3_unet_model_GN import build_deeplabv3_unet_3class as build_specific_3class
|
| 749 |
+
elif config.architecture_name == 'transunet':
|
| 750 |
+
from trans_unet_model import build_trans_unet_3class as build_specific_3class
|
| 751 |
+
else:
|
| 752 |
+
print(f"❌ Error loading model: Invalid Model Name")
|
| 753 |
+
raise
|
| 754 |
+
|
| 755 |
+
model = build_specific_3class(input_shape=(256, 256, 1), num_classes=config.num_classes)
|
| 756 |
+
|
| 757 |
+
print(f"Model parameters: {model.count_params():,}\n")
|
| 758 |
+
|
| 759 |
+
# Optimizer (will be updated with ReduceLROnPlateau)
|
| 760 |
+
optimizer = tf.keras.optimizers.legacy.Adam(
|
| 761 |
+
config.learning_rate, beta_1=config.beta_1
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
# Initialize optimizer variables
|
| 765 |
+
print("Initializing optimizer variables...")
|
| 766 |
+
dummy_input = tf.zeros((1, 256, 256, 1))
|
| 767 |
+
|
| 768 |
+
with tf.GradientTape() as tape:
|
| 769 |
+
output = model(dummy_input, training=True)
|
| 770 |
+
dummy_loss = tf.reduce_mean(output)
|
| 771 |
+
|
| 772 |
+
# Apply dummy gradients to build optimizer variables
|
| 773 |
+
grads = tape.gradient(dummy_loss, model.trainable_variables)
|
| 774 |
+
optimizer.apply_gradients(zip(grads, model.trainable_variables))
|
| 775 |
+
print("✅ Optimizer variables initialized\n")
|
| 776 |
+
|
| 777 |
+
# Checkpoint
|
| 778 |
+
checkpoint = tf.train.Checkpoint(
|
| 779 |
+
optimizer=optimizer,
|
| 780 |
+
model=model
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
checkpoint_prefix = config.checkpoint_dir / "ckpt"
|
| 784 |
+
manager = tf.train.CheckpointManager(
|
| 785 |
+
checkpoint, config.checkpoint_dir, max_to_keep=1
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
if manager.latest_checkpoint:
|
| 789 |
+
checkpoint.restore(manager.latest_checkpoint)
|
| 790 |
+
print(f"✅ Restored from checkpoint: {manager.latest_checkpoint}\n")
|
| 791 |
+
else:
|
| 792 |
+
print("Starting training from scratch\n")
|
| 793 |
+
|
| 794 |
+
# Get example for visualization
|
| 795 |
+
skip_n = 1 # min(100 // config.batch_size, val_size - 1)
|
| 796 |
+
example_paired, example_target, _, _ = next(iter(val_dataset.skip(skip_n).take(20)))
|
| 797 |
+
|
| 798 |
+
print("Initializing metrics computer...")
|
| 799 |
+
if config.num_classes == 4:
|
| 800 |
+
class_names = ['Background', 'Ventricles', 'Normal_WMH', 'Abnormal_WMH']
|
| 801 |
+
elif config.num_classes == 3:
|
| 802 |
+
class_names = ['Background', 'Ventricles', 'Abnormal_WMH']
|
| 803 |
+
|
| 804 |
+
# Training history
|
| 805 |
+
history = {
|
| 806 |
+
'train_loss': [],
|
| 807 |
+
'wce_loss': [],
|
| 808 |
+
'ufd_loss': [],
|
| 809 |
+
'val_loss': [],
|
| 810 |
+
'val_loss_wce': [],
|
| 811 |
+
'val_loss_ufd': [],
|
| 812 |
+
'val_metrics': [],
|
| 813 |
+
'beta_value': []
|
| 814 |
+
}
|
| 815 |
+
|
| 816 |
+
# Training loop
|
| 817 |
+
best_val_loss = float('inf')
|
| 818 |
+
best_val_dice = float('-inf')
|
| 819 |
+
exclude_class = 2 if config.num_classes == 4 else None # Exclude class 2 only in 4-class
|
| 820 |
+
|
| 821 |
+
try:
|
| 822 |
+
for epoch in range(config.epochs):
|
| 823 |
+
start_time = time.time()
|
| 824 |
+
|
| 825 |
+
# Compute beta for this epoch
|
| 826 |
+
beta_value = compute_beta_schedule(
|
| 827 |
+
epoch, config.epochs,
|
| 828 |
+
config.beta_threshold, config.beta_smoothness
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
# Training metrics
|
| 832 |
+
epoch_losses = []
|
| 833 |
+
epoch_loss_wce = []
|
| 834 |
+
epoch_loss_ufd = []
|
| 835 |
+
|
| 836 |
+
# Training loop
|
| 837 |
+
|
| 838 |
+
# Update learning rate based on epoch
|
| 839 |
+
|
| 840 |
+
# y1 = 2 * np.exp(-np.log(400) * x) # original
|
| 841 |
+
# y2 = 2 * np.exp(-np.log(400) * x**2) # milder
|
| 842 |
+
# y3 = 2 * np.exp(-np.log(400) * x**3) # even milder ✅
|
| 843 |
+
# y4 = 2 * np.exp(-np.log(400) * x**5) # very mild
|
| 844 |
+
|
| 845 |
+
new_lr = config.learning_rate * np.exp(-np.log(400) * (epoch / config.epochs)**3) # Steadily and exponentially decay from 2e-4 to 5e-7
|
| 846 |
+
optimizer.learning_rate.assign(new_lr)
|
| 847 |
+
|
| 848 |
+
print(f"\nEpoch {epoch+1}/{config.epochs} (β={beta_value.numpy():.4f}) (lr={new_lr*10000:.3f} 10-4)")
|
| 849 |
+
train_bar = tqdm(train_dataset, total=train_size, desc="Training")
|
| 850 |
+
|
| 851 |
+
for paired_input, target_mask, patient_id_tensor, slice_num_tensor in train_bar:
|
| 852 |
+
|
| 853 |
+
patient_id = patient_id_tensor.numpy()[0].decode('utf-8') # batch dim + bytes→str
|
| 854 |
+
slice_num = int(slice_num_tensor.numpy()[0])
|
| 855 |
+
|
| 856 |
+
# ✅ Prepare inputs: normalize FLAIR + one-hot encode target
|
| 857 |
+
flair_normalized, target_onehot = prepare_inputs(
|
| 858 |
+
paired_input, target_mask, config.num_classes
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
# Train step
|
| 862 |
+
loss, wce_loss, ufd_loss = train_step(
|
| 863 |
+
flair_normalized, target_onehot,
|
| 864 |
+
model, optimizer, class_weights,
|
| 865 |
+
beta_value, config.focal_gamma
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
epoch_losses.append(loss.numpy())
|
| 869 |
+
epoch_loss_wce.append(wce_loss.numpy())
|
| 870 |
+
epoch_loss_ufd.append(ufd_loss.numpy())
|
| 871 |
+
|
| 872 |
+
# Update progress bar
|
| 873 |
+
train_bar.set_postfix({
|
| 874 |
+
'seg_loss': f"{loss.numpy():.5f}",
|
| 875 |
+
'wce_loss': f"{wce_loss.numpy():.5f}",
|
| 876 |
+
'ufd_loss': f"{ufd_loss.numpy():.5f}",
|
| 877 |
+
})
|
| 878 |
+
|
| 879 |
+
# Calculate epoch average
|
| 880 |
+
avg_train_loss = np.mean(epoch_losses)
|
| 881 |
+
avg_train_loss_wce = np.mean(epoch_loss_wce)
|
| 882 |
+
avg_train_loss_ufd = np.mean(epoch_loss_ufd)
|
| 883 |
+
|
| 884 |
+
history['train_loss'].append(avg_train_loss)
|
| 885 |
+
history['wce_loss'].append(avg_train_loss_wce)
|
| 886 |
+
history['ufd_loss'].append(avg_train_loss_ufd)
|
| 887 |
+
history['beta_value'].append(float(beta_value.numpy()))
|
| 888 |
+
|
| 889 |
+
# Validation
|
| 890 |
+
val_losses = []
|
| 891 |
+
val_losses_wce = []
|
| 892 |
+
val_losses_ufd = []
|
| 893 |
+
all_val_true = []
|
| 894 |
+
all_val_pred = []
|
| 895 |
+
|
| 896 |
+
for val_paired, val_target, patient_id_tensor, slice_num_tensor in val_dataset:
|
| 897 |
+
try:
|
| 898 |
+
|
| 899 |
+
patient_id = patient_id_tensor.numpy()[0].decode('utf-8') # batch dim + bytes→str
|
| 900 |
+
slice_num = int(slice_num_tensor.numpy()[0])
|
| 901 |
+
|
| 902 |
+
val_flair_norm, val_target_onehot = prepare_inputs(
|
| 903 |
+
val_paired, val_target, config.num_classes
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
val_pred = model(val_flair_norm, training=False)
|
| 907 |
+
|
| 908 |
+
val_loss, val_wce_loss, val_ufd_loss = adaptive_segmentation_loss(
|
| 909 |
+
val_target_onehot, val_pred, class_weights,
|
| 910 |
+
beta_value, focal_gamma=config.focal_gamma, exclude_class=exclude_class
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
# Store true and prediction values for metrics calculation
|
| 914 |
+
all_val_true.append(val_target_onehot)
|
| 915 |
+
all_val_pred.append(val_pred)
|
| 916 |
+
|
| 917 |
+
if not tf.math.is_nan(val_loss):
|
| 918 |
+
val_losses.append(val_loss.numpy())
|
| 919 |
+
val_losses_wce.append(val_wce_loss.numpy())
|
| 920 |
+
val_losses_ufd.append(val_ufd_loss.numpy())
|
| 921 |
+
except:
|
| 922 |
+
continue
|
| 923 |
+
|
| 924 |
+
if len(val_losses) > 0:
|
| 925 |
+
avg_val_loss = np.mean(val_losses)
|
| 926 |
+
avg_val_loss_wce = np.mean(val_losses_wce)
|
| 927 |
+
avg_val_loss_ufd = np.mean(val_losses_ufd)
|
| 928 |
+
|
| 929 |
+
history['val_loss'].append(avg_val_loss)
|
| 930 |
+
history['val_loss_wce'].append(avg_val_loss_wce)
|
| 931 |
+
history['val_loss_ufd'].append(avg_val_loss_ufd)
|
| 932 |
+
|
| 933 |
+
# Compute class-wise metrics
|
| 934 |
+
val_metrics = compute_classwise_metrics(
|
| 935 |
+
all_val_true, all_val_pred,
|
| 936 |
+
config.num_classes#, exclude_class=exclude_class
|
| 937 |
+
)
|
| 938 |
+
history['val_metrics'].append(val_metrics)
|
| 939 |
+
|
| 940 |
+
# Print validation results
|
| 941 |
+
epoch_time = time.time() - start_time
|
| 942 |
+
print(f"\n{'='*70}")
|
| 943 |
+
print(f"Epoch {epoch+1}/{config.epochs} Summary (Time: {epoch_time:.2f}s)")
|
| 944 |
+
print(f"{'='*70}")
|
| 945 |
+
print(f"Training Loss: {avg_train_loss:.4f} | wce_loss: {avg_train_loss_wce:.4f} | ufd_loss: {avg_train_loss_ufd:.4f}")
|
| 946 |
+
print(f"Validation Loss: {avg_val_loss:.4f}")
|
| 947 |
+
print(f"\nClass-wise Dice Scores:")
|
| 948 |
+
for class_name, dice_val in val_metrics['dice'].items():
|
| 949 |
+
if class_name != 'mean':
|
| 950 |
+
print(f" {class_name}: {dice_val:.4f}")
|
| 951 |
+
if class_name == f"class_{config.num_classes - 1}":
|
| 952 |
+
abwmh_val_dice = dice_val
|
| 953 |
+
elif class_name == f"class_1":
|
| 954 |
+
vent_val_dice = dice_val
|
| 955 |
+
print(f" Mean Dice: {val_metrics['dice']['mean']:.4f}")
|
| 956 |
+
print(f"\nClass-wise Precision:")
|
| 957 |
+
for class_name, prec_val in val_metrics['precision'].items():
|
| 958 |
+
if class_name != 'mean':
|
| 959 |
+
print(f" {class_name}: {prec_val:.4f}")
|
| 960 |
+
print(f" Mean Precision: {val_metrics['precision']['mean']:.4f}")
|
| 961 |
+
print(f"\nClass-wise Recall:")
|
| 962 |
+
for class_name, rec_val in val_metrics['recall'].items():
|
| 963 |
+
if class_name != 'mean':
|
| 964 |
+
print(f" {class_name}: {rec_val:.4f}")
|
| 965 |
+
print(f" Mean Recall: {val_metrics['recall']['mean']:.4f}")
|
| 966 |
+
print(f"{'='*70}\n")
|
| 967 |
+
|
| 968 |
+
# Save best model based on overall validation performance
|
| 969 |
+
overal_val_performance = 0.6 * abwmh_val_dice + 0.3 * vent_val_dice + 0.1 * (1 - 1*avg_val_loss)
|
| 970 |
+
if overal_val_performance > best_val_dice and beta_value.numpy() > 0.9:
|
| 971 |
+
best_val_dice = overal_val_performance
|
| 972 |
+
model.save_weights(f"{config.checkpoint_dir}/best_dice_model.h5")
|
| 973 |
+
print(f"✓ Best model saved (performance: {best_val_dice:.4f})")
|
| 974 |
+
else:
|
| 975 |
+
print("Warning: No valid validation batches")
|
| 976 |
+
history['val_loss'].append(float('nan'))
|
| 977 |
+
history['val_metrics'].append({})
|
| 978 |
+
|
| 979 |
+
# Save checkpoint
|
| 980 |
+
if (epoch + 1) % 5 == 0 and False:
|
| 981 |
+
manager.save()
|
| 982 |
+
print(f" 💾 Saved checkpoint")
|
| 983 |
+
|
| 984 |
+
# Generate sample images
|
| 985 |
+
if ((epoch + 1) % 5 == 0 or epoch == 0) or True:
|
| 986 |
+
generate_and_save_images(
|
| 987 |
+
model, example_paired, example_target,
|
| 988 |
+
epoch + 1, config.figures_dir, config.num_classes
|
| 989 |
+
)
|
| 990 |
+
print(f" 📊 Saved visualization")
|
| 991 |
+
|
| 992 |
+
# # Save final model
|
| 993 |
+
# final_model_path = config.checkpoint_dir / "final_model.h5"
|
| 994 |
+
# model.save(final_model_path)
|
| 995 |
+
# print(f"\n✅ Training complete! Final model saved to {final_model_path}")
|
| 996 |
+
|
| 997 |
+
# Save history
|
| 998 |
+
history_serializable = {
|
| 999 |
+
key: [float(val) if isinstance(val, (int, float, np.number)) else val
|
| 1000 |
+
for val in values]
|
| 1001 |
+
for key, values in history.items()
|
| 1002 |
+
}
|
| 1003 |
+
|
| 1004 |
+
history_file = config.checkpoint_dir / "history.json"
|
| 1005 |
+
with open(history_file, 'w') as f:
|
| 1006 |
+
json.dump(history_serializable, f, indent=2)
|
| 1007 |
+
|
| 1008 |
+
return history, history_file
|
| 1009 |
+
|
| 1010 |
+
finally:
|
| 1011 |
+
# CRITICAL: Always cleanup, even if training fails
|
| 1012 |
+
print("\n🧹 Cleaning up resources...")
|
| 1013 |
+
|
| 1014 |
+
# Delete models explicitly to break references
|
| 1015 |
+
try:
|
| 1016 |
+
del model
|
| 1017 |
+
del optimizer
|
| 1018 |
+
del checkpoint
|
| 1019 |
+
del manager
|
| 1020 |
+
del train_dataset
|
| 1021 |
+
del val_dataset
|
| 1022 |
+
print("✅ Deleted model objects")
|
| 1023 |
+
except Exception as e:
|
| 1024 |
+
print(f"⚠️ Error deleting objects: {e}")
|
| 1025 |
+
|
| 1026 |
+
# Clear GPU memory
|
| 1027 |
+
clear_gpu_memory()
|
| 1028 |
+
|
| 1029 |
+
# Check final GPU memory
|
| 1030 |
+
get_gpu_memory_info()
|
| 1031 |
+
|
| 1032 |
+
###################### Main Execution ######################
|
| 1033 |
+
|
| 1034 |
+
if __name__ == "__main__":
|
| 1035 |
+
|
| 1036 |
+
# Example: Train a specific U-Net for 3-class, standard preprocessing, fold 0
|
| 1037 |
+
|
| 1038 |
+
config = ExperimentConfig(
|
| 1039 |
+
variant=3,
|
| 1040 |
+
preprocessing='standard',
|
| 1041 |
+
class_scenario='3class',
|
| 1042 |
+
fold_id=0,
|
| 1043 |
+
architecture_name='dlv3unet' # ['unet', 'attnunet', 'dlv3unet', transunet']
|
| 1044 |
+
)
|
| 1045 |
+
|
| 1046 |
+
history, history_path = train_net(config)
|
| 1047 |
+
|
| 1048 |
+
print("\n" + "="*70)
|
| 1049 |
+
print("U-NET TRAINING COMPLETE")
|
| 1050 |
+
print("="*70)
|
| 1051 |
+
|
models/for_WMH_Vent/model_training_scripts/trans_unet_model.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###################### Libraries ######################
|
| 2 |
+
# Deep Learning
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import keras
|
| 5 |
+
from keras.models import Model, load_model
|
| 6 |
+
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
|
| 7 |
+
from keras import backend as K
|
| 8 |
+
from tensorflow.keras import layers, optimizers, callbacks
|
| 9 |
+
from keras.utils import to_categorical
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def build_trans_unet_3class(input_shape=(256, 256, 1), num_classes=3):
|
| 13 |
+
"""
|
| 14 |
+
TransUNet architecture for medical image segmentation
|
| 15 |
+
Combines CNN encoder with Transformer decoder
|
| 16 |
+
"""
|
| 17 |
+
inputs = layers.Input(input_shape)
|
| 18 |
+
|
| 19 |
+
# ==================== CNN ENCODER ====================
|
| 20 |
+
# Stage 1
|
| 21 |
+
conv1 = layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)
|
| 22 |
+
conv1 = layers.Conv2D(64, 3, padding='same', activation='relu')(conv1)
|
| 23 |
+
conv1 = layers.Dropout(0.1)(conv1)
|
| 24 |
+
pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
|
| 25 |
+
|
| 26 |
+
# Stage 2
|
| 27 |
+
conv2 = layers.Conv2D(128, 3, padding='same', activation='relu')(pool1)
|
| 28 |
+
conv2 = layers.Conv2D(128, 3, padding='same', activation='relu')(conv2)
|
| 29 |
+
conv2 = layers.Dropout(0.1)(conv2)
|
| 30 |
+
pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
|
| 31 |
+
|
| 32 |
+
# Stage 3
|
| 33 |
+
conv3 = layers.Conv2D(256, 3, padding='same', activation='relu')(pool2)
|
| 34 |
+
conv3 = layers.Conv2D(256, 3, padding='same', activation='relu')(conv3)
|
| 35 |
+
conv3 = layers.Dropout(0.2)(conv3)
|
| 36 |
+
pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
|
| 37 |
+
|
| 38 |
+
# Stage 4
|
| 39 |
+
conv4 = layers.Conv2D(512, 3, padding='same', activation='relu')(pool3)
|
| 40 |
+
conv4 = layers.Conv2D(512, 3, padding='same', activation='relu')(conv4)
|
| 41 |
+
conv4 = layers.Dropout(0.2)(conv4)
|
| 42 |
+
pool4 = layers.MaxPooling2D(pool_size=(2, 2))(conv4)
|
| 43 |
+
|
| 44 |
+
# ==================== TRANSFORMER BOTTLENECK ====================
|
| 45 |
+
# Bottleneck features: 16x16x512
|
| 46 |
+
bottleneck = layers.Conv2D(768, 3, padding='same', activation='relu')(pool4)
|
| 47 |
+
bottleneck = layers.Dropout(0.3)(bottleneck)
|
| 48 |
+
|
| 49 |
+
# Prepare for transformer: reshape to sequence
|
| 50 |
+
batch_size = tf.shape(bottleneck)[0]
|
| 51 |
+
h, w = 16, 16 # feature map dimensions at bottleneck
|
| 52 |
+
d_model = 768 # transformer dimension
|
| 53 |
+
|
| 54 |
+
# Flatten spatial dimensions for transformer
|
| 55 |
+
transformer_input = layers.Reshape((h * w, d_model))(bottleneck)
|
| 56 |
+
|
| 57 |
+
# Add positional encoding
|
| 58 |
+
positions = tf.range(start=0, limit=h * w, delta=1)
|
| 59 |
+
pos_encoding = layers.Embedding(h * w, d_model)(positions)
|
| 60 |
+
transformer_input = transformer_input + pos_encoding
|
| 61 |
+
|
| 62 |
+
# Multi-head attention blocks
|
| 63 |
+
for _ in range(4): # 4 transformer layers
|
| 64 |
+
# Multi-head attention
|
| 65 |
+
attention_output = layers.MultiHeadAttention(
|
| 66 |
+
num_heads=8, key_dim=d_model // 8, dropout=0.1
|
| 67 |
+
)(transformer_input, transformer_input)
|
| 68 |
+
attention_output = layers.Dropout(0.1)(attention_output)
|
| 69 |
+
transformer_input = layers.LayerNormalization()(transformer_input + attention_output)
|
| 70 |
+
|
| 71 |
+
# Feed forward network
|
| 72 |
+
ffn = layers.Dense(d_model * 2, activation='relu')(transformer_input)
|
| 73 |
+
ffn = layers.Dropout(0.1)(ffn)
|
| 74 |
+
ffn = layers.Dense(d_model)(ffn)
|
| 75 |
+
ffn = layers.Dropout(0.1)(ffn)
|
| 76 |
+
transformer_input = layers.LayerNormalization()(transformer_input + ffn)
|
| 77 |
+
|
| 78 |
+
# Reshape back to spatial format
|
| 79 |
+
transformer_output = layers.Reshape((h, w, d_model))(transformer_input)
|
| 80 |
+
|
| 81 |
+
# Project back to bottleneck dimension
|
| 82 |
+
bottleneck_enhanced = layers.Conv2D(512, 1, activation='relu')(transformer_output)
|
| 83 |
+
bottleneck_enhanced = layers.Dropout(0.3)(bottleneck_enhanced)
|
| 84 |
+
|
| 85 |
+
# ==================== CNN DECODER ====================
|
| 86 |
+
# Decoder Stage 1
|
| 87 |
+
up1 = layers.Conv2DTranspose(512, 2, strides=2, padding='same')(bottleneck_enhanced)
|
| 88 |
+
concat1 = layers.Concatenate()([up1, conv4])
|
| 89 |
+
concat1 = layers.Dropout(0.2)(concat1)
|
| 90 |
+
|
| 91 |
+
conv_up1 = layers.Conv2D(512, 3, padding='same', activation='relu')(concat1)
|
| 92 |
+
conv_up1 = layers.Conv2D(512, 3, padding='same', activation='relu')(conv_up1)
|
| 93 |
+
|
| 94 |
+
# Decoder Stage 2
|
| 95 |
+
up2 = layers.Conv2DTranspose(256, 2, strides=2, padding='same')(conv_up1)
|
| 96 |
+
concat2 = layers.Concatenate()([up2, conv3])
|
| 97 |
+
concat2 = layers.Dropout(0.2)(concat2)
|
| 98 |
+
|
| 99 |
+
conv_up2 = layers.Conv2D(256, 3, padding='same', activation='relu')(concat2)
|
| 100 |
+
conv_up2 = layers.Conv2D(256, 3, padding='same', activation='relu')(conv_up2)
|
| 101 |
+
|
| 102 |
+
# Decoder Stage 3
|
| 103 |
+
up3 = layers.Conv2DTranspose(128, 2, strides=2, padding='same')(conv_up2)
|
| 104 |
+
concat3 = layers.Concatenate()([up3, conv2])
|
| 105 |
+
concat3 = layers.Dropout(0.1)(concat3)
|
| 106 |
+
|
| 107 |
+
conv_up3 = layers.Conv2D(128, 3, padding='same', activation='relu')(concat3)
|
| 108 |
+
conv_up3 = layers.Conv2D(128, 3, padding='same', activation='relu')(conv_up3)
|
| 109 |
+
|
| 110 |
+
# Decoder Stage 4
|
| 111 |
+
up4 = layers.Conv2DTranspose(64, 2, strides=2, padding='same')(conv_up3)
|
| 112 |
+
concat4 = layers.Concatenate()([up4, conv1])
|
| 113 |
+
concat4 = layers.Dropout(0.1)(concat4)
|
| 114 |
+
|
| 115 |
+
conv_up4 = layers.Conv2D(64, 3, padding='same', activation='relu')(concat4)
|
| 116 |
+
conv_up4 = layers.Conv2D(64, 3, padding='same', activation='relu')(conv_up4)
|
| 117 |
+
|
| 118 |
+
# ==================== OUTPUT LAYER ====================
|
| 119 |
+
if num_classes == 1:
|
| 120 |
+
outputs = layers.Conv2D(1, 1, activation='sigmoid')(conv_up4)
|
| 121 |
+
else:
|
| 122 |
+
outputs = layers.Conv2D(num_classes, 1, activation='softmax')(conv_up4)
|
| 123 |
+
|
| 124 |
+
model = tf.keras.Model(inputs, outputs, name='TransUNet')
|
| 125 |
+
return model
|
models/for_WMH_Vent/model_training_scripts/unet_model.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###################### Libraries ######################
|
| 2 |
+
# Deep Learning
|
| 3 |
+
import keras
|
| 4 |
+
from keras.models import Model
|
| 5 |
+
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_unet_3class(input_shape=(256, 256, 1), num_classes=3):
|
| 9 |
+
"""Enhanced U-Net architecture with batch normalization and dropout"""
|
| 10 |
+
inputs = Input(input_shape)
|
| 11 |
+
|
| 12 |
+
# Encoder with batch normalization
|
| 13 |
+
c1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
|
| 14 |
+
# c1 = keras.layers.BatchNormalization()(c1)
|
| 15 |
+
c1 = Conv2D(64, 3, activation='relu', padding='same')(c1)
|
| 16 |
+
# c1 = keras.layers.BatchNormalization()(c1)
|
| 17 |
+
p1 = MaxPooling2D()(c1)
|
| 18 |
+
p1 = keras.layers.Dropout(0.1)(p1)
|
| 19 |
+
|
| 20 |
+
c2 = Conv2D(128, 3, activation='relu', padding='same')(p1)
|
| 21 |
+
# c2 = keras.layers.BatchNormalization()(c2)
|
| 22 |
+
c2 = Conv2D(128, 3, activation='relu', padding='same')(c2)
|
| 23 |
+
# c2 = keras.layers.BatchNormalization()(c2)
|
| 24 |
+
p2 = MaxPooling2D()(c2)
|
| 25 |
+
p2 = keras.layers.Dropout(0.1)(p2)
|
| 26 |
+
|
| 27 |
+
c3 = Conv2D(256, 3, activation='relu', padding='same')(p2)
|
| 28 |
+
# c3 = keras.layers.BatchNormalization()(c3)
|
| 29 |
+
c3 = Conv2D(256, 3, activation='relu', padding='same')(c3)
|
| 30 |
+
# c3 = keras.layers.BatchNormalization()(c3)
|
| 31 |
+
p3 = MaxPooling2D()(c3)
|
| 32 |
+
p3 = keras.layers.Dropout(0.2)(p3)
|
| 33 |
+
|
| 34 |
+
c4 = Conv2D(512, 3, activation='relu', padding='same')(p3)
|
| 35 |
+
# c4 = keras.layers.BatchNormalization()(c4)
|
| 36 |
+
c4 = Conv2D(512, 3, activation='relu', padding='same')(c4)
|
| 37 |
+
# c4 = keras.layers.BatchNormalization()(c4)
|
| 38 |
+
p4 = MaxPooling2D()(c4)
|
| 39 |
+
p4 = keras.layers.Dropout(0.2)(p4)
|
| 40 |
+
|
| 41 |
+
# Bottleneck
|
| 42 |
+
c5 = Conv2D(1024, 3, activation='relu', padding='same')(p4)
|
| 43 |
+
# c5 = keras.layers.BatchNormalization()(c5)
|
| 44 |
+
c5 = Conv2D(1024, 3, activation='relu', padding='same')(c5)
|
| 45 |
+
# c5 = keras.layers.BatchNormalization()(c5)
|
| 46 |
+
c5 = keras.layers.Dropout(0.3)(c5)
|
| 47 |
+
|
| 48 |
+
# Decoder
|
| 49 |
+
u6 = Conv2DTranspose(512, 2, strides=2, padding='same')(c5)
|
| 50 |
+
u6 = concatenate([u6, c4])
|
| 51 |
+
u6 = keras.layers.Dropout(0.2)(u6)
|
| 52 |
+
c6 = Conv2D(512, 3, activation='relu', padding='same')(u6)
|
| 53 |
+
# c6 = keras.layers.BatchNormalization()(c6)
|
| 54 |
+
c6 = Conv2D(512, 3, activation='relu', padding='same')(c6)
|
| 55 |
+
# c6 = keras.layers.BatchNormalization()(c6)
|
| 56 |
+
|
| 57 |
+
u7 = Conv2DTranspose(256, 2, strides=2, padding='same')(c6)
|
| 58 |
+
u7 = concatenate([u7, c3])
|
| 59 |
+
u7 = keras.layers.Dropout(0.2)(u7)
|
| 60 |
+
c7 = Conv2D(256, 3, activation='relu', padding='same')(u7)
|
| 61 |
+
# c7 = keras.layers.BatchNormalization()(c7)
|
| 62 |
+
c7 = Conv2D(256, 3, activation='relu', padding='same')(c7)
|
| 63 |
+
# c7 = keras.layers.BatchNormalization()(c7)
|
| 64 |
+
|
| 65 |
+
u8 = Conv2DTranspose(128, 2, strides=2, padding='same')(c7)
|
| 66 |
+
u8 = concatenate([u8, c2])
|
| 67 |
+
u8 = keras.layers.Dropout(0.1)(u8)
|
| 68 |
+
c8 = Conv2D(128, 3, activation='relu', padding='same')(u8)
|
| 69 |
+
# c8 = keras.layers.BatchNormalization()(c8)
|
| 70 |
+
c8 = Conv2D(128, 3, activation='relu', padding='same')(c8)
|
| 71 |
+
# c8 = keras.layers.BatchNormalization()(c8)
|
| 72 |
+
|
| 73 |
+
u9 = Conv2DTranspose(64, 2, strides=2, padding='same')(c8)
|
| 74 |
+
u9 = concatenate([u9, c1])
|
| 75 |
+
u9 = keras.layers.Dropout(0.1)(u9)
|
| 76 |
+
c9 = Conv2D(64, 3, activation='relu', padding='same')(u9)
|
| 77 |
+
# c9 = keras.layers.BatchNormalization()(c9)
|
| 78 |
+
c9 = Conv2D(64, 3, activation='relu', padding='same')(c9)
|
| 79 |
+
# c9 = keras.layers.BatchNormalization()(c9)
|
| 80 |
+
|
| 81 |
+
# Output layer
|
| 82 |
+
if num_classes == 1:
|
| 83 |
+
outputs = Conv2D(1, 1, activation='sigmoid')(c9)
|
| 84 |
+
else:
|
| 85 |
+
outputs = Conv2D(num_classes, 1, activation='softmax')(c9)
|
| 86 |
+
|
| 87 |
+
return Model(inputs, outputs, name='UNet')
|
models/for_WMH_Vent/model_training_scripts/utility_functions.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P4 Article - Utility Functions
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Developer:
|
| 6 |
+
"Mahdi Bashiri Bawil"
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import gc
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
from tensorflow.keras import backend as K
|
| 12 |
+
|
| 13 |
+
print("TensorFlow Version:", tf.__version__)
|
| 14 |
+
|
| 15 |
+
###################### GPU Configuration ######################
|
| 16 |
+
|
| 17 |
+
# Configure GPU memory growth
|
| 18 |
+
physical_devices = tf.config.list_physical_devices('GPU')
|
| 19 |
+
if physical_devices:
|
| 20 |
+
try:
|
| 21 |
+
for device in physical_devices:
|
| 22 |
+
tf.config.experimental.set_memory_growth(device, True)
|
| 23 |
+
print("✅ GPU memory growth enabled")
|
| 24 |
+
print(f" Available GPUs: {len(physical_devices)}")
|
| 25 |
+
except RuntimeError as e:
|
| 26 |
+
print(f"GPU configuration error: {e}")
|
| 27 |
+
else:
|
| 28 |
+
print("⚠️ No GPU detected - training will be slow")
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
GPU Memory Management for Sequential Experiments
|
| 32 |
+
To properly release memory between experiments
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def clear_gpu_memory():
|
| 37 |
+
"""
|
| 38 |
+
Comprehensive GPU memory cleanup between experiments
|
| 39 |
+
Call this after each experiment completes
|
| 40 |
+
"""
|
| 41 |
+
print("\n" + "="*70)
|
| 42 |
+
print("CLEANING UP GPU MEMORY")
|
| 43 |
+
print("="*70)
|
| 44 |
+
|
| 45 |
+
# Clear Keras session
|
| 46 |
+
K.clear_session()
|
| 47 |
+
print("✅ Cleared Keras session")
|
| 48 |
+
|
| 49 |
+
# Force garbage collection multiple times
|
| 50 |
+
for _ in range(3):
|
| 51 |
+
gc.collect()
|
| 52 |
+
print("✅ Ran garbage collection (3 passes)")
|
| 53 |
+
|
| 54 |
+
# Reset TensorFlow graphs
|
| 55 |
+
tf.compat.v1.reset_default_graph()
|
| 56 |
+
print("✅ Reset default graph")
|
| 57 |
+
|
| 58 |
+
# Additional cleanup for TF 2.x
|
| 59 |
+
try:
|
| 60 |
+
# Clear any cached tensors
|
| 61 |
+
tf.config.experimental.reset_memory_stats('GPU:0')
|
| 62 |
+
print("✅ Reset GPU memory stats")
|
| 63 |
+
except:
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
# CRITICAL: Reset GPU memory allocator
|
| 67 |
+
# This forces TensorFlow to release memory back to the system
|
| 68 |
+
try:
|
| 69 |
+
physical_devices = tf.config.list_physical_devices('GPU')
|
| 70 |
+
if physical_devices:
|
| 71 |
+
# Disable and re-enable memory growth to flush allocator
|
| 72 |
+
for device in physical_devices:
|
| 73 |
+
tf.config.experimental.set_memory_growth(device, False)
|
| 74 |
+
tf.config.experimental.set_memory_growth(device, True)
|
| 75 |
+
print("✅ Reset memory growth (flushed allocator)")
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"⚠️ Could not reset memory growth: {e}")
|
| 78 |
+
|
| 79 |
+
print("="*70 + "\n")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_gpu_memory_info():
|
| 83 |
+
"""
|
| 84 |
+
Print current GPU memory usage
|
| 85 |
+
Useful for monitoring memory leaks
|
| 86 |
+
"""
|
| 87 |
+
try:
|
| 88 |
+
gpu_devices = tf.config.list_physical_devices('GPU')
|
| 89 |
+
if gpu_devices:
|
| 90 |
+
for device in gpu_devices:
|
| 91 |
+
details = tf.config.experimental.get_memory_info(device.name.replace('/physical_device:', ''))
|
| 92 |
+
current_mb = details['current'] / 1024**2
|
| 93 |
+
peak_mb = details['peak'] / 1024**2
|
| 94 |
+
print(f"GPU Memory - Current: {current_mb:.1f} MB, Peak: {peak_mb:.1f} MB")
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"Could not get GPU memory info: {e}")
|
models/for_WMH_Vent/results_fold_avg_var_1_zscore2/models/standard_3class/download_models.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Visit our Hugging Face link for downloading the trained models.
|