Upload 33 files
Browse files- .gitattributes +20 -0
- models/for_GM/class_weights_gm/class_weights_fold0_standard_binary.json +23 -0
- models/for_GM/data_splits_sp_gm/SP_GM_fold_assignments.json +1295 -0
- models/for_GM/model_training_scripts/p1_compute_class_weights.py +336 -0
- models/for_GM/model_training_scripts/p1_data_loader.py +847 -0
- models/for_GM/model_training_scripts/p1_pix2pix_var5.py +1313 -0
- models/for_GM/model_training_scripts/p1_predict_new_data_gm.py +477 -0
- models/for_GM/model_training_scripts/unet_model.py +87 -0
- models/for_GM/model_training_scripts/utility_functions.py +97 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_1.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_2.png +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_discriminator.h5 +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_generator.h5 +3 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/config.json +19 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/download_models.txt +1 -0
- models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/history.json +145 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_1.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_2.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_1.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_2.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_1.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_2.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_1.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_2.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_1.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_2.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_1.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_2.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_1.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_2.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_1.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_2.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_1.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_2.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_1.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_2.png filter=lfs diff=lfs merge=lfs -text
|
models/for_GM/class_weights_gm/class_weights_fold0_standard_binary.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fold_id": 0,
|
| 3 |
+
"class_scenario": "binary",
|
| 4 |
+
"preprocessing": "standard",
|
| 5 |
+
"num_classes": 2,
|
| 6 |
+
"total_pixels": 88539136,
|
| 7 |
+
"class_pixel_counts": [
|
| 8 |
+
79575838,
|
| 9 |
+
8963298
|
| 10 |
+
],
|
| 11 |
+
"class_frequencies": [
|
| 12 |
+
0.8987645644068629,
|
| 13 |
+
0.10123543559313702
|
| 14 |
+
],
|
| 15 |
+
"class_weights": [
|
| 16 |
+
0.20247246624134155,
|
| 17 |
+
1.7975275337586585
|
| 18 |
+
],
|
| 19 |
+
"class_names": [
|
| 20 |
+
"Background",
|
| 21 |
+
"Specialized GM"
|
| 22 |
+
]
|
| 23 |
+
}
|
models/for_GM/data_splits_sp_gm/SP_GM_fold_assignments.json
ADDED
|
@@ -0,0 +1,1295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_patients": 268,
|
| 4 |
+
"test_patients": 26,
|
| 5 |
+
"trainval_patients": 242,
|
| 6 |
+
"n_folds": 5,
|
| 7 |
+
"random_seed": 42,
|
| 8 |
+
"datasets": [
|
| 9 |
+
"Local_SAI_GM_sp"
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
"test_set": {
|
| 13 |
+
"patients": [
|
| 14 |
+
"117524",
|
| 15 |
+
"132287",
|
| 16 |
+
"105597",
|
| 17 |
+
"120429",
|
| 18 |
+
"117949",
|
| 19 |
+
"126395",
|
| 20 |
+
"134240",
|
| 21 |
+
"120907",
|
| 22 |
+
"106506",
|
| 23 |
+
"110784",
|
| 24 |
+
"118754",
|
| 25 |
+
"112997",
|
| 26 |
+
"112730",
|
| 27 |
+
"129466",
|
| 28 |
+
"105911",
|
| 29 |
+
"111008",
|
| 30 |
+
"129008",
|
| 31 |
+
"129044",
|
| 32 |
+
"110543",
|
| 33 |
+
"117276",
|
| 34 |
+
"114454",
|
| 35 |
+
"104474",
|
| 36 |
+
"114770",
|
| 37 |
+
"130578",
|
| 38 |
+
"116740",
|
| 39 |
+
"107680"
|
| 40 |
+
],
|
| 41 |
+
"n_patients": 26
|
| 42 |
+
},
|
| 43 |
+
"folds": {
|
| 44 |
+
"fold_0": {
|
| 45 |
+
"train_patients": [
|
| 46 |
+
"101228",
|
| 47 |
+
"101627",
|
| 48 |
+
"102035",
|
| 49 |
+
"102313",
|
| 50 |
+
"104252",
|
| 51 |
+
"104280",
|
| 52 |
+
"104447",
|
| 53 |
+
"104453",
|
| 54 |
+
"104670",
|
| 55 |
+
"104797",
|
| 56 |
+
"104810",
|
| 57 |
+
"104871",
|
| 58 |
+
"105074",
|
| 59 |
+
"105549",
|
| 60 |
+
"105755",
|
| 61 |
+
"105917",
|
| 62 |
+
"105978",
|
| 63 |
+
"106270",
|
| 64 |
+
"106536",
|
| 65 |
+
"106639",
|
| 66 |
+
"106780",
|
| 67 |
+
"106976",
|
| 68 |
+
"107130",
|
| 69 |
+
"107455",
|
| 70 |
+
"107508",
|
| 71 |
+
"107539",
|
| 72 |
+
"107630",
|
| 73 |
+
"107966",
|
| 74 |
+
"107997",
|
| 75 |
+
"108295",
|
| 76 |
+
"108344",
|
| 77 |
+
"108444",
|
| 78 |
+
"108726",
|
| 79 |
+
"108975",
|
| 80 |
+
"109141",
|
| 81 |
+
"109267",
|
| 82 |
+
"109395",
|
| 83 |
+
"109654",
|
| 84 |
+
"109816",
|
| 85 |
+
"109923",
|
| 86 |
+
"109944",
|
| 87 |
+
"110012",
|
| 88 |
+
"110157",
|
| 89 |
+
"110218",
|
| 90 |
+
"110280",
|
| 91 |
+
"110327",
|
| 92 |
+
"110497",
|
| 93 |
+
"111140",
|
| 94 |
+
"111189",
|
| 95 |
+
"111489",
|
| 96 |
+
"111691",
|
| 97 |
+
"111852",
|
| 98 |
+
"112414",
|
| 99 |
+
"112657",
|
| 100 |
+
"112659",
|
| 101 |
+
"112765",
|
| 102 |
+
"112776",
|
| 103 |
+
"113394",
|
| 104 |
+
"114058",
|
| 105 |
+
"114128",
|
| 106 |
+
"114266",
|
| 107 |
+
"114304",
|
| 108 |
+
"114525",
|
| 109 |
+
"114585",
|
| 110 |
+
"114903",
|
| 111 |
+
"114990",
|
| 112 |
+
"115588",
|
| 113 |
+
"115628",
|
| 114 |
+
"115788",
|
| 115 |
+
"115799",
|
| 116 |
+
"115841",
|
| 117 |
+
"115991",
|
| 118 |
+
"116236",
|
| 119 |
+
"116246",
|
| 120 |
+
"116577",
|
| 121 |
+
"116700",
|
| 122 |
+
"116914",
|
| 123 |
+
"116937",
|
| 124 |
+
"117314",
|
| 125 |
+
"117385",
|
| 126 |
+
"117814",
|
| 127 |
+
"118018",
|
| 128 |
+
"118078",
|
| 129 |
+
"118409",
|
| 130 |
+
"118450",
|
| 131 |
+
"118481",
|
| 132 |
+
"118605",
|
| 133 |
+
"118719",
|
| 134 |
+
"118755",
|
| 135 |
+
"119730",
|
| 136 |
+
"120638",
|
| 137 |
+
"120749",
|
| 138 |
+
"120857",
|
| 139 |
+
"121140",
|
| 140 |
+
"121404",
|
| 141 |
+
"121499",
|
| 142 |
+
"121620",
|
| 143 |
+
"121804",
|
| 144 |
+
"121921",
|
| 145 |
+
"122000",
|
| 146 |
+
"122316",
|
| 147 |
+
"122762",
|
| 148 |
+
"122884",
|
| 149 |
+
"123575",
|
| 150 |
+
"124187",
|
| 151 |
+
"124899",
|
| 152 |
+
"125198",
|
| 153 |
+
"125465",
|
| 154 |
+
"125567",
|
| 155 |
+
"125798",
|
| 156 |
+
"126228",
|
| 157 |
+
"126396",
|
| 158 |
+
"126445",
|
| 159 |
+
"126465",
|
| 160 |
+
"126494",
|
| 161 |
+
"126523",
|
| 162 |
+
"126542",
|
| 163 |
+
"126704",
|
| 164 |
+
"126768",
|
| 165 |
+
"126779",
|
| 166 |
+
"127096",
|
| 167 |
+
"127513",
|
| 168 |
+
"127758",
|
| 169 |
+
"127816",
|
| 170 |
+
"127897",
|
| 171 |
+
"128785",
|
| 172 |
+
"128901",
|
| 173 |
+
"129055",
|
| 174 |
+
"129100",
|
| 175 |
+
"129637",
|
| 176 |
+
"129739",
|
| 177 |
+
"130214",
|
| 178 |
+
"130282",
|
| 179 |
+
"130366",
|
| 180 |
+
"130371",
|
| 181 |
+
"130402",
|
| 182 |
+
"130556",
|
| 183 |
+
"130662",
|
| 184 |
+
"130801",
|
| 185 |
+
"131040",
|
| 186 |
+
"131231",
|
| 187 |
+
"131235",
|
| 188 |
+
"131364",
|
| 189 |
+
"131444",
|
| 190 |
+
"131494",
|
| 191 |
+
"131606",
|
| 192 |
+
"131636",
|
| 193 |
+
"131792",
|
| 194 |
+
"131924",
|
| 195 |
+
"132155",
|
| 196 |
+
"132207",
|
| 197 |
+
"132282",
|
| 198 |
+
"132296",
|
| 199 |
+
"132589",
|
| 200 |
+
"132597",
|
| 201 |
+
"132605",
|
| 202 |
+
"132920",
|
| 203 |
+
"133196",
|
| 204 |
+
"133338",
|
| 205 |
+
"133562",
|
| 206 |
+
"133710",
|
| 207 |
+
"133814",
|
| 208 |
+
"133850",
|
| 209 |
+
"133886",
|
| 210 |
+
"133934",
|
| 211 |
+
"133946",
|
| 212 |
+
"134032",
|
| 213 |
+
"134654",
|
| 214 |
+
"134728",
|
| 215 |
+
"134919",
|
| 216 |
+
"134955",
|
| 217 |
+
"135467",
|
| 218 |
+
"135503",
|
| 219 |
+
"135687",
|
| 220 |
+
"135695",
|
| 221 |
+
"135697",
|
| 222 |
+
"135725",
|
| 223 |
+
"135733",
|
| 224 |
+
"135830",
|
| 225 |
+
"135855",
|
| 226 |
+
"136104",
|
| 227 |
+
"136105",
|
| 228 |
+
"136175",
|
| 229 |
+
"136220",
|
| 230 |
+
"136310",
|
| 231 |
+
"136382",
|
| 232 |
+
"136793",
|
| 233 |
+
"136817",
|
| 234 |
+
"136966",
|
| 235 |
+
"136996",
|
| 236 |
+
"137104",
|
| 237 |
+
"137508",
|
| 238 |
+
"137675"
|
| 239 |
+
],
|
| 240 |
+
"val_patients": [
|
| 241 |
+
"104420",
|
| 242 |
+
"104518",
|
| 243 |
+
"104520",
|
| 244 |
+
"104899",
|
| 245 |
+
"104937",
|
| 246 |
+
"105302",
|
| 247 |
+
"105465",
|
| 248 |
+
"106063",
|
| 249 |
+
"106200",
|
| 250 |
+
"106905",
|
| 251 |
+
"107233",
|
| 252 |
+
"107739",
|
| 253 |
+
"108807",
|
| 254 |
+
"110540",
|
| 255 |
+
"112055",
|
| 256 |
+
"112378",
|
| 257 |
+
"113046",
|
| 258 |
+
"113845",
|
| 259 |
+
"114836",
|
| 260 |
+
"116268",
|
| 261 |
+
"116768",
|
| 262 |
+
"118660",
|
| 263 |
+
"118807",
|
| 264 |
+
"119095",
|
| 265 |
+
"119224",
|
| 266 |
+
"120781",
|
| 267 |
+
"122020",
|
| 268 |
+
"122288",
|
| 269 |
+
"125626",
|
| 270 |
+
"127511",
|
| 271 |
+
"127545",
|
| 272 |
+
"127870",
|
| 273 |
+
"129164",
|
| 274 |
+
"129916",
|
| 275 |
+
"130308",
|
| 276 |
+
"130373",
|
| 277 |
+
"131919",
|
| 278 |
+
"132371",
|
| 279 |
+
"132812",
|
| 280 |
+
"132896",
|
| 281 |
+
"133340",
|
| 282 |
+
"134197",
|
| 283 |
+
"134555",
|
| 284 |
+
"135628",
|
| 285 |
+
"136144",
|
| 286 |
+
"136589",
|
| 287 |
+
"137168",
|
| 288 |
+
"137617",
|
| 289 |
+
"137624"
|
| 290 |
+
],
|
| 291 |
+
"n_train": 193,
|
| 292 |
+
"n_val": 49
|
| 293 |
+
},
|
| 294 |
+
"fold_1": {
|
| 295 |
+
"train_patients": [
|
| 296 |
+
"101228",
|
| 297 |
+
"101627",
|
| 298 |
+
"102035",
|
| 299 |
+
"102313",
|
| 300 |
+
"104252",
|
| 301 |
+
"104420",
|
| 302 |
+
"104447",
|
| 303 |
+
"104453",
|
| 304 |
+
"104518",
|
| 305 |
+
"104520",
|
| 306 |
+
"104670",
|
| 307 |
+
"104810",
|
| 308 |
+
"104871",
|
| 309 |
+
"104899",
|
| 310 |
+
"104937",
|
| 311 |
+
"105074",
|
| 312 |
+
"105302",
|
| 313 |
+
"105465",
|
| 314 |
+
"105549",
|
| 315 |
+
"105755",
|
| 316 |
+
"105917",
|
| 317 |
+
"105978",
|
| 318 |
+
"106063",
|
| 319 |
+
"106200",
|
| 320 |
+
"106270",
|
| 321 |
+
"106536",
|
| 322 |
+
"106905",
|
| 323 |
+
"107130",
|
| 324 |
+
"107233",
|
| 325 |
+
"107455",
|
| 326 |
+
"107539",
|
| 327 |
+
"107630",
|
| 328 |
+
"107739",
|
| 329 |
+
"107966",
|
| 330 |
+
"107997",
|
| 331 |
+
"108295",
|
| 332 |
+
"108444",
|
| 333 |
+
"108726",
|
| 334 |
+
"108807",
|
| 335 |
+
"108975",
|
| 336 |
+
"109141",
|
| 337 |
+
"109267",
|
| 338 |
+
"109395",
|
| 339 |
+
"109654",
|
| 340 |
+
"109923",
|
| 341 |
+
"109944",
|
| 342 |
+
"110012",
|
| 343 |
+
"110280",
|
| 344 |
+
"110327",
|
| 345 |
+
"110497",
|
| 346 |
+
"110540",
|
| 347 |
+
"111140",
|
| 348 |
+
"111189",
|
| 349 |
+
"111489",
|
| 350 |
+
"111691",
|
| 351 |
+
"112055",
|
| 352 |
+
"112378",
|
| 353 |
+
"112659",
|
| 354 |
+
"112765",
|
| 355 |
+
"112776",
|
| 356 |
+
"113046",
|
| 357 |
+
"113394",
|
| 358 |
+
"113845",
|
| 359 |
+
"114058",
|
| 360 |
+
"114128",
|
| 361 |
+
"114266",
|
| 362 |
+
"114525",
|
| 363 |
+
"114585",
|
| 364 |
+
"114836",
|
| 365 |
+
"114903",
|
| 366 |
+
"115588",
|
| 367 |
+
"115788",
|
| 368 |
+
"115799",
|
| 369 |
+
"115841",
|
| 370 |
+
"115991",
|
| 371 |
+
"116236",
|
| 372 |
+
"116246",
|
| 373 |
+
"116268",
|
| 374 |
+
"116577",
|
| 375 |
+
"116768",
|
| 376 |
+
"116937",
|
| 377 |
+
"117314",
|
| 378 |
+
"117385",
|
| 379 |
+
"118018",
|
| 380 |
+
"118078",
|
| 381 |
+
"118450",
|
| 382 |
+
"118481",
|
| 383 |
+
"118605",
|
| 384 |
+
"118660",
|
| 385 |
+
"118755",
|
| 386 |
+
"118807",
|
| 387 |
+
"119095",
|
| 388 |
+
"119224",
|
| 389 |
+
"120749",
|
| 390 |
+
"120781",
|
| 391 |
+
"120857",
|
| 392 |
+
"121499",
|
| 393 |
+
"121620",
|
| 394 |
+
"121804",
|
| 395 |
+
"122020",
|
| 396 |
+
"122288",
|
| 397 |
+
"122316",
|
| 398 |
+
"122762",
|
| 399 |
+
"122884",
|
| 400 |
+
"123575",
|
| 401 |
+
"124899",
|
| 402 |
+
"125198",
|
| 403 |
+
"125465",
|
| 404 |
+
"125567",
|
| 405 |
+
"125626",
|
| 406 |
+
"125798",
|
| 407 |
+
"126396",
|
| 408 |
+
"126465",
|
| 409 |
+
"126494",
|
| 410 |
+
"126542",
|
| 411 |
+
"126704",
|
| 412 |
+
"126779",
|
| 413 |
+
"127096",
|
| 414 |
+
"127511",
|
| 415 |
+
"127513",
|
| 416 |
+
"127545",
|
| 417 |
+
"127870",
|
| 418 |
+
"127897",
|
| 419 |
+
"128785",
|
| 420 |
+
"129055",
|
| 421 |
+
"129100",
|
| 422 |
+
"129164",
|
| 423 |
+
"129739",
|
| 424 |
+
"129916",
|
| 425 |
+
"130214",
|
| 426 |
+
"130282",
|
| 427 |
+
"130308",
|
| 428 |
+
"130371",
|
| 429 |
+
"130373",
|
| 430 |
+
"130402",
|
| 431 |
+
"130556",
|
| 432 |
+
"130662",
|
| 433 |
+
"130801",
|
| 434 |
+
"131231",
|
| 435 |
+
"131444",
|
| 436 |
+
"131494",
|
| 437 |
+
"131606",
|
| 438 |
+
"131636",
|
| 439 |
+
"131792",
|
| 440 |
+
"131919",
|
| 441 |
+
"131924",
|
| 442 |
+
"132155",
|
| 443 |
+
"132207",
|
| 444 |
+
"132282",
|
| 445 |
+
"132296",
|
| 446 |
+
"132371",
|
| 447 |
+
"132589",
|
| 448 |
+
"132812",
|
| 449 |
+
"132896",
|
| 450 |
+
"132920",
|
| 451 |
+
"133196",
|
| 452 |
+
"133340",
|
| 453 |
+
"133562",
|
| 454 |
+
"133710",
|
| 455 |
+
"133814",
|
| 456 |
+
"133850",
|
| 457 |
+
"133886",
|
| 458 |
+
"134032",
|
| 459 |
+
"134197",
|
| 460 |
+
"134555",
|
| 461 |
+
"134654",
|
| 462 |
+
"134728",
|
| 463 |
+
"134919",
|
| 464 |
+
"134955",
|
| 465 |
+
"135467",
|
| 466 |
+
"135503",
|
| 467 |
+
"135628",
|
| 468 |
+
"135687",
|
| 469 |
+
"135695",
|
| 470 |
+
"135697",
|
| 471 |
+
"135725",
|
| 472 |
+
"135733",
|
| 473 |
+
"135830",
|
| 474 |
+
"136104",
|
| 475 |
+
"136144",
|
| 476 |
+
"136175",
|
| 477 |
+
"136220",
|
| 478 |
+
"136589",
|
| 479 |
+
"136793",
|
| 480 |
+
"136817",
|
| 481 |
+
"136966",
|
| 482 |
+
"136996",
|
| 483 |
+
"137104",
|
| 484 |
+
"137168",
|
| 485 |
+
"137508",
|
| 486 |
+
"137617",
|
| 487 |
+
"137624",
|
| 488 |
+
"137675"
|
| 489 |
+
],
|
| 490 |
+
"val_patients": [
|
| 491 |
+
"104280",
|
| 492 |
+
"104797",
|
| 493 |
+
"106639",
|
| 494 |
+
"106780",
|
| 495 |
+
"106976",
|
| 496 |
+
"107508",
|
| 497 |
+
"108344",
|
| 498 |
+
"109816",
|
| 499 |
+
"110157",
|
| 500 |
+
"110218",
|
| 501 |
+
"111852",
|
| 502 |
+
"112414",
|
| 503 |
+
"112657",
|
| 504 |
+
"114304",
|
| 505 |
+
"114990",
|
| 506 |
+
"115628",
|
| 507 |
+
"116700",
|
| 508 |
+
"116914",
|
| 509 |
+
"117814",
|
| 510 |
+
"118409",
|
| 511 |
+
"118719",
|
| 512 |
+
"119730",
|
| 513 |
+
"120638",
|
| 514 |
+
"121140",
|
| 515 |
+
"121404",
|
| 516 |
+
"121921",
|
| 517 |
+
"122000",
|
| 518 |
+
"124187",
|
| 519 |
+
"126228",
|
| 520 |
+
"126445",
|
| 521 |
+
"126523",
|
| 522 |
+
"126768",
|
| 523 |
+
"127758",
|
| 524 |
+
"127816",
|
| 525 |
+
"128901",
|
| 526 |
+
"129637",
|
| 527 |
+
"130366",
|
| 528 |
+
"131040",
|
| 529 |
+
"131235",
|
| 530 |
+
"131364",
|
| 531 |
+
"132597",
|
| 532 |
+
"132605",
|
| 533 |
+
"133338",
|
| 534 |
+
"133934",
|
| 535 |
+
"133946",
|
| 536 |
+
"135855",
|
| 537 |
+
"136105",
|
| 538 |
+
"136310",
|
| 539 |
+
"136382"
|
| 540 |
+
],
|
| 541 |
+
"n_train": 193,
|
| 542 |
+
"n_val": 49
|
| 543 |
+
},
|
| 544 |
+
"fold_2": {
|
| 545 |
+
"train_patients": [
|
| 546 |
+
"101627",
|
| 547 |
+
"102313",
|
| 548 |
+
"104280",
|
| 549 |
+
"104420",
|
| 550 |
+
"104447",
|
| 551 |
+
"104453",
|
| 552 |
+
"104518",
|
| 553 |
+
"104520",
|
| 554 |
+
"104797",
|
| 555 |
+
"104810",
|
| 556 |
+
"104871",
|
| 557 |
+
"104899",
|
| 558 |
+
"104937",
|
| 559 |
+
"105074",
|
| 560 |
+
"105302",
|
| 561 |
+
"105465",
|
| 562 |
+
"105549",
|
| 563 |
+
"105755",
|
| 564 |
+
"105978",
|
| 565 |
+
"106063",
|
| 566 |
+
"106200",
|
| 567 |
+
"106639",
|
| 568 |
+
"106780",
|
| 569 |
+
"106905",
|
| 570 |
+
"106976",
|
| 571 |
+
"107233",
|
| 572 |
+
"107455",
|
| 573 |
+
"107508",
|
| 574 |
+
"107630",
|
| 575 |
+
"107739",
|
| 576 |
+
"107966",
|
| 577 |
+
"107997",
|
| 578 |
+
"108344",
|
| 579 |
+
"108444",
|
| 580 |
+
"108726",
|
| 581 |
+
"108807",
|
| 582 |
+
"109141",
|
| 583 |
+
"109267",
|
| 584 |
+
"109395",
|
| 585 |
+
"109654",
|
| 586 |
+
"109816",
|
| 587 |
+
"109923",
|
| 588 |
+
"109944",
|
| 589 |
+
"110012",
|
| 590 |
+
"110157",
|
| 591 |
+
"110218",
|
| 592 |
+
"110280",
|
| 593 |
+
"110327",
|
| 594 |
+
"110497",
|
| 595 |
+
"110540",
|
| 596 |
+
"111489",
|
| 597 |
+
"111691",
|
| 598 |
+
"111852",
|
| 599 |
+
"112055",
|
| 600 |
+
"112378",
|
| 601 |
+
"112414",
|
| 602 |
+
"112657",
|
| 603 |
+
"112765",
|
| 604 |
+
"112776",
|
| 605 |
+
"113046",
|
| 606 |
+
"113394",
|
| 607 |
+
"113845",
|
| 608 |
+
"114304",
|
| 609 |
+
"114525",
|
| 610 |
+
"114585",
|
| 611 |
+
"114836",
|
| 612 |
+
"114903",
|
| 613 |
+
"114990",
|
| 614 |
+
"115628",
|
| 615 |
+
"115788",
|
| 616 |
+
"115799",
|
| 617 |
+
"115841",
|
| 618 |
+
"116236",
|
| 619 |
+
"116246",
|
| 620 |
+
"116268",
|
| 621 |
+
"116577",
|
| 622 |
+
"116700",
|
| 623 |
+
"116768",
|
| 624 |
+
"116914",
|
| 625 |
+
"117314",
|
| 626 |
+
"117814",
|
| 627 |
+
"118018",
|
| 628 |
+
"118078",
|
| 629 |
+
"118409",
|
| 630 |
+
"118450",
|
| 631 |
+
"118481",
|
| 632 |
+
"118605",
|
| 633 |
+
"118660",
|
| 634 |
+
"118719",
|
| 635 |
+
"118755",
|
| 636 |
+
"118807",
|
| 637 |
+
"119095",
|
| 638 |
+
"119224",
|
| 639 |
+
"119730",
|
| 640 |
+
"120638",
|
| 641 |
+
"120749",
|
| 642 |
+
"120781",
|
| 643 |
+
"121140",
|
| 644 |
+
"121404",
|
| 645 |
+
"121499",
|
| 646 |
+
"121804",
|
| 647 |
+
"121921",
|
| 648 |
+
"122000",
|
| 649 |
+
"122020",
|
| 650 |
+
"122288",
|
| 651 |
+
"122762",
|
| 652 |
+
"122884",
|
| 653 |
+
"123575",
|
| 654 |
+
"124187",
|
| 655 |
+
"124899",
|
| 656 |
+
"125198",
|
| 657 |
+
"125626",
|
| 658 |
+
"126228",
|
| 659 |
+
"126445",
|
| 660 |
+
"126523",
|
| 661 |
+
"126542",
|
| 662 |
+
"126768",
|
| 663 |
+
"127096",
|
| 664 |
+
"127511",
|
| 665 |
+
"127513",
|
| 666 |
+
"127545",
|
| 667 |
+
"127758",
|
| 668 |
+
"127816",
|
| 669 |
+
"127870",
|
| 670 |
+
"127897",
|
| 671 |
+
"128785",
|
| 672 |
+
"128901",
|
| 673 |
+
"129100",
|
| 674 |
+
"129164",
|
| 675 |
+
"129637",
|
| 676 |
+
"129739",
|
| 677 |
+
"129916",
|
| 678 |
+
"130214",
|
| 679 |
+
"130282",
|
| 680 |
+
"130308",
|
| 681 |
+
"130366",
|
| 682 |
+
"130371",
|
| 683 |
+
"130373",
|
| 684 |
+
"130402",
|
| 685 |
+
"130801",
|
| 686 |
+
"131040",
|
| 687 |
+
"131231",
|
| 688 |
+
"131235",
|
| 689 |
+
"131364",
|
| 690 |
+
"131444",
|
| 691 |
+
"131494",
|
| 692 |
+
"131792",
|
| 693 |
+
"131919",
|
| 694 |
+
"132155",
|
| 695 |
+
"132207",
|
| 696 |
+
"132282",
|
| 697 |
+
"132296",
|
| 698 |
+
"132371",
|
| 699 |
+
"132589",
|
| 700 |
+
"132597",
|
| 701 |
+
"132605",
|
| 702 |
+
"132812",
|
| 703 |
+
"132896",
|
| 704 |
+
"133196",
|
| 705 |
+
"133338",
|
| 706 |
+
"133340",
|
| 707 |
+
"133562",
|
| 708 |
+
"133710",
|
| 709 |
+
"133814",
|
| 710 |
+
"133850",
|
| 711 |
+
"133934",
|
| 712 |
+
"133946",
|
| 713 |
+
"134032",
|
| 714 |
+
"134197",
|
| 715 |
+
"134555",
|
| 716 |
+
"134654",
|
| 717 |
+
"134919",
|
| 718 |
+
"134955",
|
| 719 |
+
"135467",
|
| 720 |
+
"135503",
|
| 721 |
+
"135628",
|
| 722 |
+
"135697",
|
| 723 |
+
"135725",
|
| 724 |
+
"135830",
|
| 725 |
+
"135855",
|
| 726 |
+
"136104",
|
| 727 |
+
"136105",
|
| 728 |
+
"136144",
|
| 729 |
+
"136175",
|
| 730 |
+
"136220",
|
| 731 |
+
"136310",
|
| 732 |
+
"136382",
|
| 733 |
+
"136589",
|
| 734 |
+
"136966",
|
| 735 |
+
"137168",
|
| 736 |
+
"137508",
|
| 737 |
+
"137617",
|
| 738 |
+
"137624",
|
| 739 |
+
"137675"
|
| 740 |
+
],
|
| 741 |
+
"val_patients": [
|
| 742 |
+
"101228",
|
| 743 |
+
"102035",
|
| 744 |
+
"104252",
|
| 745 |
+
"104670",
|
| 746 |
+
"105917",
|
| 747 |
+
"106270",
|
| 748 |
+
"106536",
|
| 749 |
+
"107130",
|
| 750 |
+
"107539",
|
| 751 |
+
"108295",
|
| 752 |
+
"108975",
|
| 753 |
+
"111140",
|
| 754 |
+
"111189",
|
| 755 |
+
"112659",
|
| 756 |
+
"114058",
|
| 757 |
+
"114128",
|
| 758 |
+
"114266",
|
| 759 |
+
"115588",
|
| 760 |
+
"115991",
|
| 761 |
+
"116937",
|
| 762 |
+
"117385",
|
| 763 |
+
"120857",
|
| 764 |
+
"121620",
|
| 765 |
+
"122316",
|
| 766 |
+
"125465",
|
| 767 |
+
"125567",
|
| 768 |
+
"125798",
|
| 769 |
+
"126396",
|
| 770 |
+
"126465",
|
| 771 |
+
"126494",
|
| 772 |
+
"126704",
|
| 773 |
+
"126779",
|
| 774 |
+
"129055",
|
| 775 |
+
"130556",
|
| 776 |
+
"130662",
|
| 777 |
+
"131606",
|
| 778 |
+
"131636",
|
| 779 |
+
"131924",
|
| 780 |
+
"132920",
|
| 781 |
+
"133886",
|
| 782 |
+
"134728",
|
| 783 |
+
"135687",
|
| 784 |
+
"135695",
|
| 785 |
+
"135733",
|
| 786 |
+
"136793",
|
| 787 |
+
"136817",
|
| 788 |
+
"136996",
|
| 789 |
+
"137104"
|
| 790 |
+
],
|
| 791 |
+
"n_train": 194,
|
| 792 |
+
"n_val": 48
|
| 793 |
+
},
|
| 794 |
+
"fold_3": {
|
| 795 |
+
"train_patients": [
|
| 796 |
+
"101228",
|
| 797 |
+
"101627",
|
| 798 |
+
"102035",
|
| 799 |
+
"104252",
|
| 800 |
+
"104280",
|
| 801 |
+
"104420",
|
| 802 |
+
"104518",
|
| 803 |
+
"104520",
|
| 804 |
+
"104670",
|
| 805 |
+
"104797",
|
| 806 |
+
"104871",
|
| 807 |
+
"104899",
|
| 808 |
+
"104937",
|
| 809 |
+
"105302",
|
| 810 |
+
"105465",
|
| 811 |
+
"105549",
|
| 812 |
+
"105755",
|
| 813 |
+
"105917",
|
| 814 |
+
"106063",
|
| 815 |
+
"106200",
|
| 816 |
+
"106270",
|
| 817 |
+
"106536",
|
| 818 |
+
"106639",
|
| 819 |
+
"106780",
|
| 820 |
+
"106905",
|
| 821 |
+
"106976",
|
| 822 |
+
"107130",
|
| 823 |
+
"107233",
|
| 824 |
+
"107508",
|
| 825 |
+
"107539",
|
| 826 |
+
"107630",
|
| 827 |
+
"107739",
|
| 828 |
+
"108295",
|
| 829 |
+
"108344",
|
| 830 |
+
"108807",
|
| 831 |
+
"108975",
|
| 832 |
+
"109267",
|
| 833 |
+
"109654",
|
| 834 |
+
"109816",
|
| 835 |
+
"109923",
|
| 836 |
+
"110012",
|
| 837 |
+
"110157",
|
| 838 |
+
"110218",
|
| 839 |
+
"110280",
|
| 840 |
+
"110327",
|
| 841 |
+
"110540",
|
| 842 |
+
"111140",
|
| 843 |
+
"111189",
|
| 844 |
+
"111489",
|
| 845 |
+
"111852",
|
| 846 |
+
"112055",
|
| 847 |
+
"112378",
|
| 848 |
+
"112414",
|
| 849 |
+
"112657",
|
| 850 |
+
"112659",
|
| 851 |
+
"112765",
|
| 852 |
+
"113046",
|
| 853 |
+
"113394",
|
| 854 |
+
"113845",
|
| 855 |
+
"114058",
|
| 856 |
+
"114128",
|
| 857 |
+
"114266",
|
| 858 |
+
"114304",
|
| 859 |
+
"114836",
|
| 860 |
+
"114990",
|
| 861 |
+
"115588",
|
| 862 |
+
"115628",
|
| 863 |
+
"115788",
|
| 864 |
+
"115799",
|
| 865 |
+
"115991",
|
| 866 |
+
"116246",
|
| 867 |
+
"116268",
|
| 868 |
+
"116700",
|
| 869 |
+
"116768",
|
| 870 |
+
"116914",
|
| 871 |
+
"116937",
|
| 872 |
+
"117314",
|
| 873 |
+
"117385",
|
| 874 |
+
"117814",
|
| 875 |
+
"118018",
|
| 876 |
+
"118078",
|
| 877 |
+
"118409",
|
| 878 |
+
"118481",
|
| 879 |
+
"118605",
|
| 880 |
+
"118660",
|
| 881 |
+
"118719",
|
| 882 |
+
"118807",
|
| 883 |
+
"119095",
|
| 884 |
+
"119224",
|
| 885 |
+
"119730",
|
| 886 |
+
"120638",
|
| 887 |
+
"120749",
|
| 888 |
+
"120781",
|
| 889 |
+
"120857",
|
| 890 |
+
"121140",
|
| 891 |
+
"121404",
|
| 892 |
+
"121499",
|
| 893 |
+
"121620",
|
| 894 |
+
"121921",
|
| 895 |
+
"122000",
|
| 896 |
+
"122020",
|
| 897 |
+
"122288",
|
| 898 |
+
"122316",
|
| 899 |
+
"122762",
|
| 900 |
+
"122884",
|
| 901 |
+
"124187",
|
| 902 |
+
"125465",
|
| 903 |
+
"125567",
|
| 904 |
+
"125626",
|
| 905 |
+
"125798",
|
| 906 |
+
"126228",
|
| 907 |
+
"126396",
|
| 908 |
+
"126445",
|
| 909 |
+
"126465",
|
| 910 |
+
"126494",
|
| 911 |
+
"126523",
|
| 912 |
+
"126704",
|
| 913 |
+
"126768",
|
| 914 |
+
"126779",
|
| 915 |
+
"127096",
|
| 916 |
+
"127511",
|
| 917 |
+
"127513",
|
| 918 |
+
"127545",
|
| 919 |
+
"127758",
|
| 920 |
+
"127816",
|
| 921 |
+
"127870",
|
| 922 |
+
"128785",
|
| 923 |
+
"128901",
|
| 924 |
+
"129055",
|
| 925 |
+
"129100",
|
| 926 |
+
"129164",
|
| 927 |
+
"129637",
|
| 928 |
+
"129916",
|
| 929 |
+
"130308",
|
| 930 |
+
"130366",
|
| 931 |
+
"130371",
|
| 932 |
+
"130373",
|
| 933 |
+
"130556",
|
| 934 |
+
"130662",
|
| 935 |
+
"130801",
|
| 936 |
+
"131040",
|
| 937 |
+
"131235",
|
| 938 |
+
"131364",
|
| 939 |
+
"131444",
|
| 940 |
+
"131606",
|
| 941 |
+
"131636",
|
| 942 |
+
"131919",
|
| 943 |
+
"131924",
|
| 944 |
+
"132207",
|
| 945 |
+
"132282",
|
| 946 |
+
"132296",
|
| 947 |
+
"132371",
|
| 948 |
+
"132589",
|
| 949 |
+
"132597",
|
| 950 |
+
"132605",
|
| 951 |
+
"132812",
|
| 952 |
+
"132896",
|
| 953 |
+
"132920",
|
| 954 |
+
"133338",
|
| 955 |
+
"133340",
|
| 956 |
+
"133562",
|
| 957 |
+
"133710",
|
| 958 |
+
"133814",
|
| 959 |
+
"133850",
|
| 960 |
+
"133886",
|
| 961 |
+
"133934",
|
| 962 |
+
"133946",
|
| 963 |
+
"134197",
|
| 964 |
+
"134555",
|
| 965 |
+
"134654",
|
| 966 |
+
"134728",
|
| 967 |
+
"134955",
|
| 968 |
+
"135467",
|
| 969 |
+
"135628",
|
| 970 |
+
"135687",
|
| 971 |
+
"135695",
|
| 972 |
+
"135733",
|
| 973 |
+
"135855",
|
| 974 |
+
"136105",
|
| 975 |
+
"136144",
|
| 976 |
+
"136175",
|
| 977 |
+
"136310",
|
| 978 |
+
"136382",
|
| 979 |
+
"136589",
|
| 980 |
+
"136793",
|
| 981 |
+
"136817",
|
| 982 |
+
"136966",
|
| 983 |
+
"136996",
|
| 984 |
+
"137104",
|
| 985 |
+
"137168",
|
| 986 |
+
"137508",
|
| 987 |
+
"137617",
|
| 988 |
+
"137624",
|
| 989 |
+
"137675"
|
| 990 |
+
],
|
| 991 |
+
"val_patients": [
|
| 992 |
+
"102313",
|
| 993 |
+
"104447",
|
| 994 |
+
"104453",
|
| 995 |
+
"104810",
|
| 996 |
+
"105074",
|
| 997 |
+
"105978",
|
| 998 |
+
"107455",
|
| 999 |
+
"107966",
|
| 1000 |
+
"107997",
|
| 1001 |
+
"108444",
|
| 1002 |
+
"108726",
|
| 1003 |
+
"109141",
|
| 1004 |
+
"109395",
|
| 1005 |
+
"109944",
|
| 1006 |
+
"110497",
|
| 1007 |
+
"111691",
|
| 1008 |
+
"112776",
|
| 1009 |
+
"114525",
|
| 1010 |
+
"114585",
|
| 1011 |
+
"114903",
|
| 1012 |
+
"115841",
|
| 1013 |
+
"116236",
|
| 1014 |
+
"116577",
|
| 1015 |
+
"118450",
|
| 1016 |
+
"118755",
|
| 1017 |
+
"121804",
|
| 1018 |
+
"123575",
|
| 1019 |
+
"124899",
|
| 1020 |
+
"125198",
|
| 1021 |
+
"126542",
|
| 1022 |
+
"127897",
|
| 1023 |
+
"129739",
|
| 1024 |
+
"130214",
|
| 1025 |
+
"130282",
|
| 1026 |
+
"130402",
|
| 1027 |
+
"131231",
|
| 1028 |
+
"131494",
|
| 1029 |
+
"131792",
|
| 1030 |
+
"132155",
|
| 1031 |
+
"133196",
|
| 1032 |
+
"134032",
|
| 1033 |
+
"134919",
|
| 1034 |
+
"135503",
|
| 1035 |
+
"135697",
|
| 1036 |
+
"135725",
|
| 1037 |
+
"135830",
|
| 1038 |
+
"136104",
|
| 1039 |
+
"136220"
|
| 1040 |
+
],
|
| 1041 |
+
"n_train": 194,
|
| 1042 |
+
"n_val": 48
|
| 1043 |
+
},
|
| 1044 |
+
"fold_4": {
|
| 1045 |
+
"train_patients": [
|
| 1046 |
+
"101228",
|
| 1047 |
+
"102035",
|
| 1048 |
+
"102313",
|
| 1049 |
+
"104252",
|
| 1050 |
+
"104280",
|
| 1051 |
+
"104420",
|
| 1052 |
+
"104447",
|
| 1053 |
+
"104453",
|
| 1054 |
+
"104518",
|
| 1055 |
+
"104520",
|
| 1056 |
+
"104670",
|
| 1057 |
+
"104797",
|
| 1058 |
+
"104810",
|
| 1059 |
+
"104899",
|
| 1060 |
+
"104937",
|
| 1061 |
+
"105074",
|
| 1062 |
+
"105302",
|
| 1063 |
+
"105465",
|
| 1064 |
+
"105917",
|
| 1065 |
+
"105978",
|
| 1066 |
+
"106063",
|
| 1067 |
+
"106200",
|
| 1068 |
+
"106270",
|
| 1069 |
+
"106536",
|
| 1070 |
+
"106639",
|
| 1071 |
+
"106780",
|
| 1072 |
+
"106905",
|
| 1073 |
+
"106976",
|
| 1074 |
+
"107130",
|
| 1075 |
+
"107233",
|
| 1076 |
+
"107455",
|
| 1077 |
+
"107508",
|
| 1078 |
+
"107539",
|
| 1079 |
+
"107739",
|
| 1080 |
+
"107966",
|
| 1081 |
+
"107997",
|
| 1082 |
+
"108295",
|
| 1083 |
+
"108344",
|
| 1084 |
+
"108444",
|
| 1085 |
+
"108726",
|
| 1086 |
+
"108807",
|
| 1087 |
+
"108975",
|
| 1088 |
+
"109141",
|
| 1089 |
+
"109395",
|
| 1090 |
+
"109816",
|
| 1091 |
+
"109944",
|
| 1092 |
+
"110157",
|
| 1093 |
+
"110218",
|
| 1094 |
+
"110497",
|
| 1095 |
+
"110540",
|
| 1096 |
+
"111140",
|
| 1097 |
+
"111189",
|
| 1098 |
+
"111691",
|
| 1099 |
+
"111852",
|
| 1100 |
+
"112055",
|
| 1101 |
+
"112378",
|
| 1102 |
+
"112414",
|
| 1103 |
+
"112657",
|
| 1104 |
+
"112659",
|
| 1105 |
+
"112776",
|
| 1106 |
+
"113046",
|
| 1107 |
+
"113845",
|
| 1108 |
+
"114058",
|
| 1109 |
+
"114128",
|
| 1110 |
+
"114266",
|
| 1111 |
+
"114304",
|
| 1112 |
+
"114525",
|
| 1113 |
+
"114585",
|
| 1114 |
+
"114836",
|
| 1115 |
+
"114903",
|
| 1116 |
+
"114990",
|
| 1117 |
+
"115588",
|
| 1118 |
+
"115628",
|
| 1119 |
+
"115841",
|
| 1120 |
+
"115991",
|
| 1121 |
+
"116236",
|
| 1122 |
+
"116268",
|
| 1123 |
+
"116577",
|
| 1124 |
+
"116700",
|
| 1125 |
+
"116768",
|
| 1126 |
+
"116914",
|
| 1127 |
+
"116937",
|
| 1128 |
+
"117385",
|
| 1129 |
+
"117814",
|
| 1130 |
+
"118409",
|
| 1131 |
+
"118450",
|
| 1132 |
+
"118660",
|
| 1133 |
+
"118719",
|
| 1134 |
+
"118755",
|
| 1135 |
+
"118807",
|
| 1136 |
+
"119095",
|
| 1137 |
+
"119224",
|
| 1138 |
+
"119730",
|
| 1139 |
+
"120638",
|
| 1140 |
+
"120781",
|
| 1141 |
+
"120857",
|
| 1142 |
+
"121140",
|
| 1143 |
+
"121404",
|
| 1144 |
+
"121620",
|
| 1145 |
+
"121804",
|
| 1146 |
+
"121921",
|
| 1147 |
+
"122000",
|
| 1148 |
+
"122020",
|
| 1149 |
+
"122288",
|
| 1150 |
+
"122316",
|
| 1151 |
+
"123575",
|
| 1152 |
+
"124187",
|
| 1153 |
+
"124899",
|
| 1154 |
+
"125198",
|
| 1155 |
+
"125465",
|
| 1156 |
+
"125567",
|
| 1157 |
+
"125626",
|
| 1158 |
+
"125798",
|
| 1159 |
+
"126228",
|
| 1160 |
+
"126396",
|
| 1161 |
+
"126445",
|
| 1162 |
+
"126465",
|
| 1163 |
+
"126494",
|
| 1164 |
+
"126523",
|
| 1165 |
+
"126542",
|
| 1166 |
+
"126704",
|
| 1167 |
+
"126768",
|
| 1168 |
+
"126779",
|
| 1169 |
+
"127511",
|
| 1170 |
+
"127545",
|
| 1171 |
+
"127758",
|
| 1172 |
+
"127816",
|
| 1173 |
+
"127870",
|
| 1174 |
+
"127897",
|
| 1175 |
+
"128901",
|
| 1176 |
+
"129055",
|
| 1177 |
+
"129164",
|
| 1178 |
+
"129637",
|
| 1179 |
+
"129739",
|
| 1180 |
+
"129916",
|
| 1181 |
+
"130214",
|
| 1182 |
+
"130282",
|
| 1183 |
+
"130308",
|
| 1184 |
+
"130366",
|
| 1185 |
+
"130373",
|
| 1186 |
+
"130402",
|
| 1187 |
+
"130556",
|
| 1188 |
+
"130662",
|
| 1189 |
+
"131040",
|
| 1190 |
+
"131231",
|
| 1191 |
+
"131235",
|
| 1192 |
+
"131364",
|
| 1193 |
+
"131494",
|
| 1194 |
+
"131606",
|
| 1195 |
+
"131636",
|
| 1196 |
+
"131792",
|
| 1197 |
+
"131919",
|
| 1198 |
+
"131924",
|
| 1199 |
+
"132155",
|
| 1200 |
+
"132371",
|
| 1201 |
+
"132597",
|
| 1202 |
+
"132605",
|
| 1203 |
+
"132812",
|
| 1204 |
+
"132896",
|
| 1205 |
+
"132920",
|
| 1206 |
+
"133196",
|
| 1207 |
+
"133338",
|
| 1208 |
+
"133340",
|
| 1209 |
+
"133886",
|
| 1210 |
+
"133934",
|
| 1211 |
+
"133946",
|
| 1212 |
+
"134032",
|
| 1213 |
+
"134197",
|
| 1214 |
+
"134555",
|
| 1215 |
+
"134728",
|
| 1216 |
+
"134919",
|
| 1217 |
+
"135503",
|
| 1218 |
+
"135628",
|
| 1219 |
+
"135687",
|
| 1220 |
+
"135695",
|
| 1221 |
+
"135697",
|
| 1222 |
+
"135725",
|
| 1223 |
+
"135733",
|
| 1224 |
+
"135830",
|
| 1225 |
+
"135855",
|
| 1226 |
+
"136104",
|
| 1227 |
+
"136105",
|
| 1228 |
+
"136144",
|
| 1229 |
+
"136220",
|
| 1230 |
+
"136310",
|
| 1231 |
+
"136382",
|
| 1232 |
+
"136589",
|
| 1233 |
+
"136793",
|
| 1234 |
+
"136817",
|
| 1235 |
+
"136996",
|
| 1236 |
+
"137104",
|
| 1237 |
+
"137168",
|
| 1238 |
+
"137617",
|
| 1239 |
+
"137624"
|
| 1240 |
+
],
|
| 1241 |
+
"val_patients": [
|
| 1242 |
+
"101627",
|
| 1243 |
+
"104871",
|
| 1244 |
+
"105549",
|
| 1245 |
+
"105755",
|
| 1246 |
+
"107630",
|
| 1247 |
+
"109267",
|
| 1248 |
+
"109654",
|
| 1249 |
+
"109923",
|
| 1250 |
+
"110012",
|
| 1251 |
+
"110280",
|
| 1252 |
+
"110327",
|
| 1253 |
+
"111489",
|
| 1254 |
+
"112765",
|
| 1255 |
+
"113394",
|
| 1256 |
+
"115788",
|
| 1257 |
+
"115799",
|
| 1258 |
+
"116246",
|
| 1259 |
+
"117314",
|
| 1260 |
+
"118018",
|
| 1261 |
+
"118078",
|
| 1262 |
+
"118481",
|
| 1263 |
+
"118605",
|
| 1264 |
+
"120749",
|
| 1265 |
+
"121499",
|
| 1266 |
+
"122762",
|
| 1267 |
+
"122884",
|
| 1268 |
+
"127096",
|
| 1269 |
+
"127513",
|
| 1270 |
+
"128785",
|
| 1271 |
+
"129100",
|
| 1272 |
+
"130371",
|
| 1273 |
+
"130801",
|
| 1274 |
+
"131444",
|
| 1275 |
+
"132207",
|
| 1276 |
+
"132282",
|
| 1277 |
+
"132296",
|
| 1278 |
+
"132589",
|
| 1279 |
+
"133562",
|
| 1280 |
+
"133710",
|
| 1281 |
+
"133814",
|
| 1282 |
+
"133850",
|
| 1283 |
+
"134654",
|
| 1284 |
+
"134955",
|
| 1285 |
+
"135467",
|
| 1286 |
+
"136175",
|
| 1287 |
+
"136966",
|
| 1288 |
+
"137508",
|
| 1289 |
+
"137675"
|
| 1290 |
+
],
|
| 1291 |
+
"n_train": 194,
|
| 1292 |
+
"n_val": 48
|
| 1293 |
+
}
|
| 1294 |
+
}
|
| 1295 |
+
}
|
models/for_GM/model_training_scripts/p1_compute_class_weights.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P1 Article - Compute Class Weights from Training Data
|
| 3 |
+
Utility script to calculate inverse frequency weights for class balancing
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python p1_compute_class_weights.py --fold 0 --scenario binary --preprocessing standard
|
| 7 |
+
|
| 8 |
+
Output:
|
| 9 |
+
Saves class weights to JSON file for reproducibility
|
| 10 |
+
Prints weights for use in training
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import json
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
# Import data loader
|
| 21 |
+
from p1_data_loader import DataConfig, P1DataLoader
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_class_frequencies(dataset, num_classes, total_samples=None):
    """
    Count how many target-mask pixels belong to each class across a dataset.

    Args:
        dataset: Iterable yielding 4-tuples whose second element is the target
            mask batch and exposes ``.numpy()``. The other three elements are
            ignored here.
        num_classes: Number of class ids to count (ids 0 .. num_classes - 1).
        total_samples: Optional number of items the dataset yields. When
            truthy it is used as the progress-bar total; otherwise the dataset
            is iterated without a progress bar.

    Returns:
        class_pixel_counts: int64 array of length ``num_classes`` holding the
            pixel count of each class.
        total_pixels: Total number of mask pixels visited. Pixels whose value
            is not one of the counted class ids are included here but in no
            per-class count.
    """
    class_pixel_counts = np.zeros(num_classes, dtype=np.int64)
    total_pixels = 0

    print(f"Computing class frequencies for {num_classes}-class scenario...")

    iterator = tqdm(dataset, total=total_samples, desc="Processing") if total_samples else dataset

    for _, target_mask, _, _ in iterator:
        masks = np.asarray(target_mask.numpy())

        # Count over the whole batch at once instead of looping mask by mask;
        # the comparison also works for float-valued masks.
        for class_id in range(num_classes):
            class_pixel_counts[class_id] += np.count_nonzero(masks == class_id)

        total_pixels += masks.size

    return class_pixel_counts, total_pixels
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def compute_inverse_frequency_weights(class_pixel_counts, num_classes):
    """
    Compute inverse frequency weights with normalization

    Args:
        class_pixel_counts: Array-like of pixel counts per class
        num_classes: Number of classes; the returned weights sum to this value

    Returns:
        class_weights: Normalized inverse frequency weights
        class_frequencies: Class frequencies (for reference)

    Raises:
        ValueError: If the counts sum to zero, i.e. no pixels were counted.
    """
    # Accept lists/tuples as well as arrays, and do the math in float64.
    counts = np.asarray(class_pixel_counts, dtype=np.float64)
    total_pixels = counts.sum()

    # A zero total would yield 0/0 = NaN frequencies and NaN weights, which
    # would then be silently written to disk; fail loudly instead.
    if total_pixels <= 0:
        raise ValueError(
            "Cannot compute class weights: total pixel count is zero "
            "(was the dataset empty?)"
        )

    # Class frequencies
    class_frequencies = counts / total_pixels

    # Inverse frequency (small epsilon keeps a class with zero pixels finite)
    epsilon = 1e-6
    inverse_freq = 1.0 / (class_frequencies + epsilon)

    # Normalize weights to sum = num_classes
    # This keeps weights in a reasonable range while maintaining relative importance
    class_weights = inverse_freq / inverse_freq.sum() * num_classes

    return class_weights, class_frequencies
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def compute_and_save_class_weights(fold_id, class_scenario, preprocessing,
                                   output_dir='class_weights_gm'):
    """
    Compute class weights for a specific fold and scenario

    Builds the training dataset of the fold, counts pixels per class, derives
    normalized inverse-frequency weights, prints a report and writes the
    result to ``class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json``
    inside ``output_dir``.

    Args:
        fold_id: Fold number (0-4)
        class_scenario: 'binary'
        preprocessing: 'standard' or 'zoomed'
        output_dir: Directory to save weights (created if missing, including
            any missing parent directories)

    Returns:
        Dictionary with weights and statistics
    """
    print("\n" + "="*70)
    print("COMPUTING CLASS WEIGHTS")
    print("="*70)
    print(f"Fold: {fold_id}")
    print(f"Scenario: {class_scenario}")
    print(f"Preprocessing: {preprocessing}")
    print("="*70 + "\n")

    # Initialize data loader
    config = DataConfig()
    data_loader = P1DataLoader(config)

    # Every scenario handled by this script is two-class (the previous
    # conditional evaluated to 2 on both branches).
    num_classes = 2

    # Load training dataset
    print("Loading training dataset...")
    train_dataset = data_loader.create_dataset_for_fold(
        fold_id=fold_id,
        split='train',
        preprocessing=preprocessing,
        class_scenario=class_scenario,
        batch_size=4,   # Larger batch for faster processing
        shuffle=False   # No need to shuffle for counting
    )

    # Get dataset size (one full pass; used as the progress-bar total below).
    # NOTE(review): with batch_size=4 each dataset item is presumably a batch,
    # so this number would count batches rather than samples - confirm.
    train_size = sum(1 for _ in train_dataset)
    print(f"Training samples: {train_size}")

    # Compute class frequencies
    class_pixel_counts, total_pixels = compute_class_frequencies(
        train_dataset, num_classes, train_size
    )

    # Compute inverse frequency weights
    class_weights, class_frequencies = compute_inverse_frequency_weights(
        class_pixel_counts, num_classes
    )

    # Print results
    print("\n" + "="*70)
    print("RESULTS")
    print("="*70)

    class_names = {
        2: ['Background', 'Specialized GM']
    }

    print(f"\nTotal pixels analyzed: {total_pixels:,}")
    print("\nClass Statistics:")
    print("-" * 70)

    for i in range(num_classes):
        print(f"Class {i} ({class_names[num_classes][i]}):")
        print(f" Pixel count: {class_pixel_counts[i]:,}")
        print(f" Frequency: {class_frequencies[i]:.6f} ({class_frequencies[i]*100:.2f}%)")
        print(f" Weight: {class_weights[i]:.4f}")
        print()

    # Save to JSON; parents=True so a nested output_dir does not fail.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    results = {
        'fold_id': fold_id,
        'class_scenario': class_scenario,
        'preprocessing': preprocessing,
        'num_classes': num_classes,
        'total_pixels': int(total_pixels),
        'class_pixel_counts': class_pixel_counts.tolist(),
        'class_frequencies': class_frequencies.tolist(),
        'class_weights': class_weights.tolist(),
        'class_names': class_names[num_classes]
    }

    filename = f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
    filepath = output_path / filename

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print("="*70)
    print(f"✅ Class weights saved to: {filepath}")
    print("="*70)

    # Print weights in format ready for code
    print("\nFor use in training script:")
    print("-" * 70)
    print(f"class_weights = tf.constant({class_weights.tolist()}, dtype=tf.float32)")
    print()

    return results
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def compute_all_scenarios_for_fold(fold_id):
    """
    Run the class-weight computation for both preprocessing variants of a fold.

    Args:
        fold_id: Fold number (0-4)

    Returns:
        Dictionary keyed by "<preprocessing>_<class_scenario>" whose values are
        the result dictionaries returned by compute_and_save_class_weights().
    """
    # (preprocessing, class_scenario) pairs, processed in this order.
    combos = (
        ('standard', 'binary'),
        ('zoomed', 'binary'),
    )

    collected = {}

    for prep_name, scenario_name in combos:
        collected[f"{prep_name}_{scenario_name}"] = compute_and_save_class_weights(
            fold_id=fold_id,
            class_scenario=scenario_name,
            preprocessing=prep_name
        )

        # Visual separator between consecutive scenario reports.
        print("\n" + "="*70 + "\n")

    return collected
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def load_class_weights(fold_id, class_scenario, preprocessing, weights_dir='class_weights_gm'):
    """
    Load previously computed class weights

    Reads ``class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json``
    from ``weights_dir`` and returns its ``class_weights`` entry.

    Args:
        fold_id: Fold number (0-4)
        class_scenario: 'binary'
        preprocessing: 'standard' or 'zoomed'
        weights_dir: Directory containing weights files

    Returns:
        class_weights: NumPy float32 array of weights

    Raises:
        FileNotFoundError: If the weights file does not exist.
    """
    filename = f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
    filepath = Path(weights_dir) / filename

    # Open directly instead of checking exists() first: one filesystem access
    # and no window in which the file can vanish between check and open.
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            results = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(
            f"Class weights not found: {filepath}\n"
            "Run compute_and_save_class_weights() first."
        ) from None

    return np.array(results['class_weights'], dtype=np.float32)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def main(argv=None):
    """
    Main entry point with argument parsing

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When None, argparse reads them from ``sys.argv`` as before,
            so existing ``main()`` callers are unaffected.

    Modes, checked in this order:
        --all-folds : every scenario of every fold (0-4)
        --all       : every scenario of the fold given by --fold
        otherwise   : the single combination given by --fold, --scenario and
                      --preprocessing (all three required)

    Missing required options end the program through ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description='Compute class weights from training data',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Single scenario
python p1_compute_class_weights.py --fold 0 --scenario binary --preprocessing standard

# All scenarios for one fold
python p1_compute_class_weights.py --fold 0 --all

# All folds (for completeness)
python p1_compute_class_weights.py --all-folds
"""
    )

    parser.add_argument(
        '--fold',
        type=int,
        choices=[0, 1, 2, 3, 4],
        help='Fold number (0-4)'
    )

    parser.add_argument(
        '--scenario',
        type=str,
        choices=['binary'],
        help='Class scenario'
    )

    parser.add_argument(
        '--preprocessing',
        type=str,
        choices=['standard', 'zoomed'],
        help='Preprocessing type'
    )

    parser.add_argument(
        '--all',
        action='store_true',
        help='Compute for all scenarios of specified fold'
    )

    parser.add_argument(
        '--all-folds',
        action='store_true',
        help='Compute for all scenarios of all folds'
    )

    args = parser.parse_args(argv)

    # Validate arguments
    if args.all_folds:
        # Compute for all folds
        for fold_id in range(5):
            print(f"\n{'='*70}")
            print(f"PROCESSING FOLD {fold_id}")
            print(f"{'='*70}\n")
            compute_all_scenarios_for_fold(fold_id)

    elif args.all:
        # Compute all scenarios for one fold
        if args.fold is None:
            parser.error("--fold is required when using --all")
        compute_all_scenarios_for_fold(args.fold)

    else:
        # Compute single scenario
        if args.fold is None or args.scenario is None or args.preprocessing is None:
            parser.error("--fold, --scenario, and --preprocessing are required")

        compute_and_save_class_weights(
            fold_id=args.fold,
            class_scenario=args.scenario,
            preprocessing=args.preprocessing
        )
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
if __name__ == "__main__":
|
| 336 |
+
main()
|
models/for_GM/model_training_scripts/p1_data_loader.py
ADDED
|
@@ -0,0 +1,847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P1 & P4 Articles - Data Loading System
|
| 3 |
+
|
| 4 |
+
Complete implementation for brain segmentation experiments
|
| 5 |
+
|
| 6 |
+
Specialized Gray Matter (GM) Segmentation with U-Net Models - Journal Paper Implementation
|
| 7 |
+
Binary segmentation: Background vs Specialized GM
|
| 8 |
+
Professional results saving and visualization for publication
|
| 9 |
+
|
| 10 |
+
This relates to our articles:
|
| 11 |
+
"Specialized gray matter segmentation via a generative adversarial network:
|
| 12 |
+
application on brain white matter hyperintensities classification"
|
| 13 |
+
|
| 14 |
+
"Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
|
| 15 |
+
A Large-Scale Multiple Sclerosis Study"
|
| 16 |
+
|
| 17 |
+
Features:
|
| 18 |
+
- Load FLAIR images and individual mask files from Cohort directory
|
| 19 |
+
- Support the Local_SAI_GM_sp dataset
|
| 20 |
+
- Handle standard and zoomed preprocessing variants
|
| 21 |
+
- Combine masks into 2-class format
|
| 22 |
+
- Create paired inputs: [FLAIR | mask] concatenated (256x512)
|
| 23 |
+
- Patient-stratified K-fold cross-validation
|
| 24 |
+
- TensorFlow dataset creation with proper batching
|
| 25 |
+
|
| 26 |
+
Authors:
|
| 27 |
+
"Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
|
| 28 |
+
|
| 29 |
+
Developer:
|
| 30 |
+
"Mahdi Bashiri Bawil"
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import os
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import Tuple, List, Dict, Optional
|
| 37 |
+
import json
|
| 38 |
+
from sklearn.model_selection import KFold
|
| 39 |
+
from tqdm import tqdm
|
| 40 |
+
import cv2 as cv
|
| 41 |
+
|
| 42 |
+
# Deep Learning
|
| 43 |
+
import tensorflow as tf
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
###################### Configuration ######################
|
| 47 |
+
|
| 48 |
+
class DataConfig:
    """Settings container for the specialized-GM (P4) segmentation experiments.

    Groups the cohort location, the per-dataset layout, the label
    definitions, the cross-validation parameters, the image geometry and
    the location of the persisted fold-assignment file.
    """

    def __init__(self):
        # ---- Data location -------------------------------------------
        # CHANGE THIS to your actual path of Data Cohort
        cohort_root = Path("/mnt/e/MBashiri/ours_articles/Paper#2/Data/Cohort")
        self.cohort_dir = cohort_root

        # ---- Per-dataset layout --------------------------------------
        local_gm = dict(
            base_path=cohort_root / 'Local_SAI_GM_sp',
            slice_range=(1, 20),       # first/last slice index, both inclusive
            patient_prefix_length=6,   # patient IDs look like "101228"
        )
        self.datasets = {'Local_SAI_GM_sp': local_gm}

        # ---- Preprocessing variants ----------------------------------
        self.preprocessing_types = ['standard', 'zoomed']

        # ---- Label scenarios -----------------------------------------
        binary_scenario = dict(
            num_classes=2,
            class_names=['Background', 'Specialized GM'],
            description='Binary: Background, Specialized GM',
            class_mapping=dict(background=0, specialized_gm=1),
        )
        self.class_scenarios = {'binary': binary_scenario}

        # ---- Cross-validation ----------------------------------------
        self.k_folds = 5
        self.test_split = 0.1      # fraction of patients held out as test set
        self.random_state = 42

        # ---- Image geometry ------------------------------------------
        self.target_size = (256, 256)
        self.paired_width = 512    # FLAIR (256) + mask (256) side by side

        # ---- Fold-assignment persistence -----------------------------
        splits_root = Path("data_splits_sp_gm")
        self.splits_dir = splits_root
        self.splits_file = splits_root / "SP_GM_fold_assignments.json"
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
###################### Helper Functions ######################
|
| 95 |
+
|
| 96 |
+
def extract_patient_id(filename: str, prefix_length: int = 6) -> str:
    """
    Return the patient identifier encoded at the start of a slice filename.

    The identifier is the text in front of the first underscore, truncated
    to at most ``prefix_length`` characters.

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"
        prefix_length: Maximum number of characters kept for the patient ID

    Returns:
        Patient ID: e.g., "101228" or "c01p01"
    """
    leading_token, _, _ = filename.partition('_')
    return leading_token[:prefix_length]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def extract_slice_number(filename: str) -> int:
    """
    Return the slice index encoded at the end of a slice filename.

    Everything from the first dot onward is discarded, then the text after
    the last underscore is parsed as an integer.

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"

    Returns:
        Slice number as integer

    Raises:
        ValueError: If the trailing token is not a valid integer.
    """
    stem = filename.partition('.')[0]
    trailing_token = stem.rpartition('_')[2]
    return int(trailing_token)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def load_flair_image(flair_path: Path, normalize: bool = False, of_z_score: bool = False) -> np.ndarray:
    """
    Load one FLAIR slice and rescale its intensities to [-1, 1].

    Args:
        flair_path: Path to the slice's .png file. When ``of_z_score`` is
            True the sibling file with the same name but a ``.npy`` suffix
            is read instead.
        normalize: If True, re-standardize (zero mean / unit std) when the
            statistics of the rescaled image look off. After the [-1, 1]
            rescaling the trigger condition (std > 2 or |mean| > 1) cannot
            be met, so this is effectively a no-op kept for interface
            compatibility.
        of_z_score: If True, load the pre-computed .npy array (described by
            the original code as already z-scored) instead of the PNG.

    Returns:
        FLAIR image as float32 with a trailing channel axis added when the
        loaded array is 2-D (e.g. (256, 256) -> (256, 256, 1)).

    Raises:
        FileNotFoundError: If the PNG cannot be read, or the .npy file does
            not exist.
    """
    if of_z_score:
        # Swap only the file suffix; a plain substring replace would also
        # rewrite ".png" occurring elsewhere in the path.
        npy_path = Path(flair_path).with_suffix('.npy')
        flair = np.load(str(npy_path)).astype(np.float32)
    else:
        # cv.imread signals failure by returning None instead of raising.
        raw = cv.imread(str(flair_path), cv.IMREAD_GRAYSCALE)
        if raw is None:
            raise FileNotFoundError(f"Could not load FLAIR image: {flair_path}")
        flair = raw.astype(np.float32)

    # Min-max rescale to [0, 1], then shift to [-1, 1].
    # NOTE(review): the source indentation was lost; this rescaling is
    # assumed to apply to both the PNG and the .npy branch -- confirm.
    lowest = np.min(flair)
    span = np.max(flair) - lowest
    if span == 0:
        # Constant image: avoid 0/0 -> NaN; every pixel maps to -1.
        flair = np.zeros_like(flair)
    else:
        flair = (flair - lowest) / span
    flair = (2 * flair) - 1

    # Ensure a trailing channel axis
    if flair.ndim == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Optional safety re-standardization (see docstring: unreachable once
    # the data is in [-1, 1], kept to preserve the original contract).
    if normalize and (np.std(flair) > 2.0 or np.abs(np.mean(flair)) > 1.0):
        flair = (flair - np.mean(flair)) / (np.std(flair) + 1e-7)

    return flair
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def load_mask_image(mask_path: Path) -> np.ndarray:
    """
    Read a mask PNG from disk and binarize it.

    Args:
        mask_path: Path to .png file

    Returns:
        uint8 array holding 1 wherever the grayscale image is non-zero and
        0 elsewhere.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    grayscale = cv.imread(str(mask_path), cv.IMREAD_GRAYSCALE)

    # imread reports failure by returning None rather than raising.
    if grayscale is None:
        raise FileNotFoundError(f"Could not load mask: {mask_path}")

    # Any non-zero intensity counts as foreground.
    foreground = grayscale > 0
    return foreground.astype(np.uint8)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def combine_masks(gm_mask: np.ndarray,
                  class_scenario: str,
                  preprocess: bool = False) -> np.ndarray:
    """
    Turn the specialized-GM mask into a class-label map.

    Args:
        gm_mask: Specialized GM mask; any value > 0 is foreground
        class_scenario: 'binary' (currently not used by this function; the
            output is always the two-class map)
        preprocess: If True, clean the mask morphologically first
            (closing, then erosion, then removal of small objects)

    Returns:
        uint8 array of the same shape as ``gm_mask`` with
        0 = Background and 1 = Specialized GM
    """
    if preprocess:
        # Imported lazily so scikit-image is only needed when requested.
        from skimage.morphology import binary_closing, binary_erosion, disk, remove_small_objects

        min_object_size = 5
        closing_kernel_size = 2
        erosion_kernel_size = 1

        gm_mask = gm_mask > 0

        gm_mask = binary_closing(gm_mask, disk(closing_kernel_size))
        gm_mask = binary_erosion(gm_mask, disk(erosion_kernel_size))
        gm_mask = remove_small_objects(gm_mask, min_size=min_object_size)

    # Class 0: Background (default)
    # Class 1: Specialized GM
    combined = np.zeros_like(gm_mask, dtype=np.uint8)
    combined[gm_mask > 0] = 1

    return combined
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def is_valid_slice(gm_mask: np.ndarray, min_pixels: Optional[int] = None) -> bool:
    """
    Decide whether a slice should be kept in the dataset.

    By default every slice is accepted, including slices whose mask is
    empty. This matches the behavior the original code had, where the
    content threshold was computed but deliberately not applied.

    Args:
        gm_mask: Specialized GM mask
        min_pixels: Optional threshold. When given, the slice is valid only
            if the mask has more than ``min_pixels`` non-zero pixels.
            ``None`` (default) disables filtering.

    Returns:
        True if the slice should be kept
    """
    if min_pixels is None:
        # Filtering disabled: keep all slices, empty or not.
        return True

    return bool(np.count_nonzero(gm_mask) > min_pixels)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def create_paired_input(flair: np.ndarray,
                        mask: np.ndarray,
                        brain_mask: np.ndarray,
                        num_classes: int,
                        if_bet=False) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create paired input: [FLAIR | mask] concatenated horizontally

    The input arrays are never modified; when brain extraction is applied
    it is performed on copies.

    Args:
        flair: FLAIR image, (H, W) or (H, W, 1), float
        mask: Combined label mask (H, W), integer class labels
        brain_mask: Brain mask (H, W); any value > 0 counts as brain
        num_classes: Highest class label, used to scale the mask channel
            (the visible caller passes 1 for the binary scenario)
        if_bet: If True, apply brain extraction: FLAIR pixels outside the
            brain are set to the FLAIR minimum and mask labels outside the
            brain are set to 0

    Returns:
        Tuple (paired, mask):
            - paired: (H, 2*W, 1) array, FLAIR on the left and the mask
              scaled to [-1, 1] on the right
            - mask: the (possibly brain-extracted) label mask
    """
    # Binarize (any non-zero value becomes True)
    brain_mask = brain_mask > 0

    # Brain extraction
    if if_bet:
        # Work on copies so the caller's arrays stay untouched.
        flair = flair.copy()
        mask = mask.copy()
        outside_brain = ~brain_mask
        flair[outside_brain] = np.min(flair)
        mask[outside_brain] = 0

    # Ensure flair is 3D
    if flair.ndim == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Scale labels to [0, 1] and then to [-1, 1] to match the FLAIR range.
    # NOTE(review): the source indentation was lost; the [-1, 1] shift is
    # assumed to sit inside the num_classes > 0 guard -- confirm. Both
    # readings agree for num_classes == 1, the only value used here.
    mask_normalized = mask.astype(np.float32)
    if num_classes > 0:
        mask_normalized = mask_normalized / num_classes
        mask_normalized = (2 * mask_normalized) - 1

    mask_3d = np.expand_dims(mask_normalized, axis=-1)

    # Concatenate horizontally: [FLAIR | mask]
    paired = np.concatenate([flair, mask_3d], axis=1)

    return paired, mask
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
###################### Patient Stratified Splitting ######################
|
| 278 |
+
|
| 279 |
+
class PatientStratifiedSplitter:
    """
    Create patient-stratified train/val/test splits

    Splitting is done on patient IDs rather than on slices, so all slices
    of one patient always end up in the same subset.
    Similar to P6 implementation but adapted for P1 data structure
    """

    def __init__(self, config: "DataConfig"):
        """Store the configuration and make sure the splits directory exists."""
        self.config = config
        # parents=True so a nested splits_dir does not fail on first use.
        self.config.splits_dir.mkdir(parents=True, exist_ok=True)

    def collect_all_patients(self) -> Dict[str, List[str]]:
        """
        Collect all unique patient IDs of every configured dataset

        Patient IDs are derived from the names of the .png files found in
        ``<base_path>/FLAIR/Preprocessed/images``. Datasets whose image
        directory does not exist are skipped with a warning.

        Returns:
            Dictionary mapping dataset_name -> sorted list of patient IDs
        """
        all_patients = {}

        for dataset_name, dataset_config in self.config.datasets.items():
            # Path to FLAIR images (standard preprocessing)
            flair_dir = dataset_config['base_path'] / 'FLAIR' / 'Preprocessed' / 'images'

            if not flair_dir.exists():
                print(f"Warning: {flair_dir} does not exist. Skipping {dataset_name}.")
                continue

            prefix_length = dataset_config['patient_prefix_length']
            patients = {
                extract_patient_id(flair_file.name, prefix_length)
                for flair_file in flair_dir.glob('*.png')
            }

            all_patients[dataset_name] = sorted(patients)
            print(f"{dataset_name}: {len(all_patients[dataset_name])} patients")

        return all_patients

    def create_patient_stratified_splits(self,
                                         save: bool = True) -> Dict:
        """
        Create patient-stratified K-fold splits

        A fraction ``config.test_split`` of the patients is held out as the
        test set; the remaining patients are divided into ``config.k_folds``
        train/validation folds.

        Args:
            save: If True, write the assignments to ``config.splits_file``

        Returns:
            Dictionary containing fold assignments

        Raises:
            ValueError: If fewer train+val patients than folds are available
        """
        all_patients = self.collect_all_patients()

        # Combine patients from all datasets
        combined_patients = []
        for patients in all_patients.values():
            combined_patients.extend(patients)

        combined_patients = np.array(combined_patients)
        total_patients = len(combined_patients)

        print(f"\nTotal unique patients: {total_patients}")

        # Step 1: hold out the test patients (fraction = config.test_split).
        # A private RandomState seeded with random_state draws exactly the
        # same sample as np.random.seed(...) + np.random.choice(...), but
        # leaves the global NumPy random state untouched.
        rng = np.random.RandomState(self.config.random_state)
        test_size = int(total_patients * self.config.test_split)

        test_indices = rng.choice(
            total_patients,
            size=test_size,
            replace=False
        )

        test_patients = combined_patients[test_indices]
        train_val_indices = np.setdiff1d(np.arange(total_patients), test_indices)
        train_val_patients = combined_patients[train_val_indices]

        print(f"Test patients: {len(test_patients)}")
        print(f"Train+Val patients: {len(train_val_patients)}")

        # Fail early with an actionable message instead of a KFold error.
        if len(train_val_patients) < self.config.k_folds:
            raise ValueError(
                f"Need at least {self.config.k_folds} train+val patients for "
                f"{self.config.k_folds}-fold splitting, found {len(train_val_patients)}. "
                f"Check that the cohort directory is correct: {self.config.cohort_dir}"
            )

        # Step 2: Create K-fold splits on train+val patients
        kfold = KFold(
            n_splits=self.config.k_folds,
            shuffle=True,
            random_state=self.config.random_state
        )

        fold_assignments = {
            'metadata': {
                'total_patients': total_patients,
                'test_patients': len(test_patients),
                'trainval_patients': len(train_val_patients),
                'n_folds': self.config.k_folds,
                'random_seed': self.config.random_state,
                'datasets': list(all_patients.keys())
            },
            'test_set': {
                'patients': test_patients.tolist(),
                'n_patients': len(test_patients)
            },
            'folds': {}
        }

        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(train_val_patients)):
            train_patients_fold = train_val_patients[train_idx]
            val_patients_fold = train_val_patients[val_idx]

            fold_assignments['folds'][f'fold_{fold_idx}'] = {
                'train_patients': train_patients_fold.tolist(),
                'val_patients': val_patients_fold.tolist(),
                'n_train': len(train_patients_fold),
                'n_val': len(val_patients_fold)
            }

            print(f"Fold {fold_idx}: Train={len(train_patients_fold)}, Val={len(val_patients_fold)}")

        # Save to JSON
        if save:
            with open(self.config.splits_file, 'w', encoding='utf-8') as f:
                json.dump(fold_assignments, f, indent=2)
            print(f"\n✅ Fold assignments saved to: {self.config.splits_file}")

        return fold_assignments

    def load_fold_assignments(self) -> Dict:
        """
        Load existing fold assignments from JSON

        Raises:
            FileNotFoundError: If ``config.splits_file`` does not exist
        """
        if not self.config.splits_file.exists():
            raise FileNotFoundError(
                f"Fold assignments not found: {self.config.splits_file}\n"
                f"Run create_patient_stratified_splits() first."
            )

        with open(self.config.splits_file, 'r', encoding='utf-8') as f:
            fold_assignments = json.load(f)

        return fold_assignments

    def verify_patient_separation(self, fold_assignments: Dict) -> bool:
        """
        Verify that the fold assignments contain no patient-level leakage

        Checks performed:
            1. No test patient appears in any fold's train or val list
            2. Within each fold, train and val patients are disjoint
            3. Every train/val patient is in validation exactly once
               (neither repeated nor missing)

        Args:
            fold_assignments: Dictionary as produced by
                create_patient_stratified_splits()

        Returns:
            True if all checks pass, False otherwise
        """
        print("\n" + "="*60)
        print("VERIFYING PATIENT SEPARATION")
        print("="*60)

        all_issues = []

        def report(issue: str) -> None:
            # Record the problem and echo it immediately.
            all_issues.append(issue)
            print(f"❌ {issue}")

        folds = fold_assignments['folds']
        test_patients = set(fold_assignments['test_set']['patients'])

        # Check 1: No patient in both test and train/val
        for fold_name, fold_data in folds.items():
            test_train_overlap = test_patients.intersection(fold_data['train_patients'])
            test_val_overlap = test_patients.intersection(fold_data['val_patients'])

            if test_train_overlap:
                report(f"{fold_name}: Test-Train overlap: {test_train_overlap}")

            if test_val_overlap:
                report(f"{fold_name}: Test-Val overlap: {test_val_overlap}")

        # Check 2: No patient in both train and val within same fold
        for fold_name, fold_data in folds.items():
            train_val_overlap = set(fold_data['train_patients']).intersection(fold_data['val_patients'])
            if train_val_overlap:
                report(f"{fold_name}: Train-Val overlap: {train_val_overlap}")

        # Check 3: Each patient in validation exactly once
        val_patient_counts = {}
        trained_patients = set()
        for fold_data in folds.values():
            trained_patients.update(fold_data['train_patients'])
            for patient in fold_data['val_patients']:
                val_patient_counts[patient] = val_patient_counts.get(patient, 0) + 1

        for patient, count in val_patient_counts.items():
            if count != 1:
                report(f"Patient {patient} in validation {count} times (should be 1)")

        # A patient that is trained on but never validated has count 0 and
        # would otherwise go unnoticed by the loop above.
        for patient in sorted(trained_patients - set(val_patient_counts)):
            report(f"Patient {patient} in validation 0 times (should be 1)")

        if not all_issues:
            print("✅ All patient separation checks passed")
            print("✅ No data leakage detected")
            return True
        else:
            print(f"\n❌ Found {len(all_issues)} issues")
            return False
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
###################### Data Loader ######################
|
| 481 |
+
|
| 482 |
+
class P1DataLoader:
    """
    Main data loader for P1 experiments

    Resolves file paths, loads FLAIR images and masks, builds the paired
    [FLAIR | mask] inputs and wraps them in TensorFlow datasets.
    """

    def __init__(self, config: "DataConfig"):
        """Store the data configuration used by all loading methods."""
        self.config = config

    def get_file_paths(self,
                       patient_id: str,
                       slice_num: int,
                       dataset_name: str,
                       preprocessing: str) -> Dict[str, Optional[Path]]:
        """
        Construct file paths for a given patient-slice

        Args:
            patient_id: e.g., "101228" or "c01p01"
            slice_num: Slice number
            dataset_name: 'Local_SAI_GM_sp'
            preprocessing: 'standard' or 'zoomed' (anything other than
                'standard' selects the zoomed sub-directories)

        Returns:
            Dictionary with keys 'flair', 'gm_mask', 'brain_mask' and
            'zoom_factors'; the latter is None unless preprocessing is
            'zoomed'. No check is made that the files exist.
        """
        base_path = self.config.datasets[dataset_name]['base_path']

        # FLAIR and ground truth use the same sub-directory layout.
        subdir = 'images' if preprocessing == 'standard' else 'zoomed/images'
        slice_file = f'{patient_id}_{slice_num}.png'

        flair_path = base_path / 'FLAIR' / 'Preprocessed' / subdir / slice_file
        gm_path = base_path / 'GroundTruth' / subdir / 'GM_Masks' / slice_file
        brain_path = base_path / 'GroundTruth' / subdir / 'Brain_Masks' / slice_file

        # Optional: zooming factors (only for zoomed preprocessing)
        zoom_factors_path = None
        if preprocessing == 'zoomed':
            zoom_factors_path = (base_path / 'FLAIR' / 'Preprocessed' / 'zoomed' / 'images'
                                 / f'{patient_id}_zooming_factors.npy')

        return {
            'flair': flair_path,
            'gm_mask': gm_path,
            'brain_mask': brain_path,
            'zoom_factors': zoom_factors_path
        }

    def load_single_slice(self,
                          patient_id: str,
                          slice_num: int,
                          dataset_name: str,
                          preprocessing: str,
                          class_scenario: str,
                          of_z_score: bool = True,
                          if_bet: bool = True,
                          pre_morph: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load a single patient-slice and create paired input

        Args:
            patient_id: Patient identifier
            slice_num: Slice number
            dataset_name: 'Local_SAI_GM_sp'
            preprocessing: 'standard' or 'zoomed'
            class_scenario: 'binary'
            of_z_score: If True, read the FLAIR slice from the .npy file
                next to the .png instead of the .png itself
            if_bet: If True, apply brain extraction using the brain mask
            pre_morph: If True, clean the GM mask morphologically before use

        Returns:
            Tuple of (paired_input, combined_mask)
                - paired_input: (256, 512, 1) FLAIR + mask concatenated
                - combined_mask: (256, 256) class labels
        """
        # Highest class label; fixed to 1 because only the binary scenario
        # is supported.
        num_classes = 1

        paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

        # Load FLAIR
        flair = load_flair_image(paths['flair'], of_z_score=of_z_score)

        # Load masks
        gm_mask = load_mask_image(paths['gm_mask'])
        brain_mask = load_mask_image(paths['brain_mask'])

        # Combine masks
        combined_mask = combine_masks(gm_mask, class_scenario, preprocess=pre_morph)

        # Create paired input
        paired_input, combined_mask = create_paired_input(
            flair, combined_mask, brain_mask, num_classes=num_classes, if_bet=if_bet
        )

        return paired_input, combined_mask

    def collect_patient_slices(self,
                               patient_list: List[str],
                               dataset_name: str,
                               preprocessing: str) -> List[Tuple[str, int, str]]:
        """
        Collect all usable slices for given patients

        A slice is collected when its FLAIR .png, GM mask and brain mask
        all exist, both masks can be read, and is_valid_slice() accepts
        the GM mask. Slices that fail the last two conditions are counted
        and reported as skipped.

        Args:
            patient_list: List of patient IDs
            dataset_name: 'Local_SAI_GM_sp'
            preprocessing: 'standard' or 'zoomed'

        Returns:
            List of tuples (patient_id, slice_num, dataset_name)
        """
        slice_min, slice_max = self.config.datasets[dataset_name]['slice_range']

        patient_slices = []
        skipped_empty = 0

        for patient_id in patient_list:
            for slice_num in range(slice_min, slice_max + 1):
                paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

                # Silently ignore slice numbers this patient does not have.
                required = (paths['flair'], paths['gm_mask'], paths['brain_mask'])
                if not all(path.exists() for path in required):
                    continue

                try:
                    gm_mask = load_mask_image(paths['gm_mask'])
                    # Loaded only to confirm the brain mask is readable.
                    load_mask_image(paths['brain_mask'])
                    keep = is_valid_slice(gm_mask)
                except Exception as e:
                    # Best effort: an unreadable slice is skipped, not fatal.
                    print(f"Warning: Could not validate {patient_id}_{slice_num}: {e}")
                    skipped_empty += 1
                    continue

                if keep:
                    patient_slices.append((patient_id, slice_num, dataset_name))
                else:
                    skipped_empty += 1

        if skipped_empty > 0:
            print(f"  ⚠️  Skipped {skipped_empty} slices with empty masks")

        return patient_slices

    def create_dataset_for_fold(self,
                                fold_id: int,
                                split: str,
                                preprocessing: str,
                                class_scenario: str,
                                batch_size: int = 1,
                                shuffle: bool = True,
                                use_z_scored: bool = True,
                                bet: bool = True) -> "tf.data.Dataset":
        """
        Create TensorFlow dataset for a specific fold and split

        Args:
            fold_id: Fold number (0-4); ignored for the 'test' split
            split: 'train', 'val', or 'test'
            preprocessing: 'standard' or 'zoomed'
            class_scenario: 'binary'
            batch_size: Batch size
            shuffle: Whether to shuffle data (only applied to 'train')
            use_z_scored: Forwarded to load_single_slice as ``of_z_score``
            bet: Forwarded to load_single_slice as ``if_bet``. Previously
                this argument was accepted but ignored and brain extraction
                was always on; the default is therefore True so that calls
                which omit it keep their old behavior.

        Returns:
            tf.data.Dataset yielding batches of
            (paired_input, combined_mask, patient_id, slice_num)

        Raises:
            ValueError: For an unknown split, or when no slice is found
            FileNotFoundError: If the fold-assignment file is missing
        """
        # Load fold assignments
        splitter = PatientStratifiedSplitter(self.config)
        fold_assignments = splitter.load_fold_assignments()

        # Get patient list for this split
        if split == 'test':
            patient_list = fold_assignments['test_set']['patients']
        elif split in ('train', 'val'):
            patient_list = fold_assignments['folds'][f'fold_{fold_id}'][f'{split}_patients']
        else:
            raise ValueError(f"Unknown split: {split}")

        print(f"\nCreating dataset for fold {fold_id}, split '{split}'")
        print(f"Patients: {len(patient_list)}")

        # Collect patient-slices from every configured dataset. The patient
        # list is not filtered per dataset: patients without files in a
        # dataset simply contribute no slices there.
        all_patient_slices = []
        for dataset_name in self.config.datasets:
            all_patient_slices.extend(
                self.collect_patient_slices(list(patient_list), dataset_name, preprocessing)
            )

        print(f"Total slices: {len(all_patient_slices)}")

        if len(all_patient_slices) == 0:
            raise ValueError(f"No data found for fold {fold_id}, split '{split}'")

        def data_generator():
            """Generator function for tf.data.Dataset"""
            for patient_id, slice_num, dataset_name in all_patient_slices:
                try:
                    paired_input, combined_mask = self.load_single_slice(
                        patient_id, slice_num, dataset_name,
                        preprocessing, class_scenario,
                        of_z_score=use_z_scored,
                        if_bet=bet
                    )
                except Exception as e:
                    # Best effort: report and skip slices that fail to load.
                    print(f"Error loading {patient_id}_{slice_num}: {e}")
                    continue
                yield paired_input, combined_mask, patient_id, slice_num

        # Create dataset
        dataset = tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(256, 512, 1), dtype=tf.float32),  # concatenated image
                tf.TensorSpec(shape=(256, 256), dtype=tf.uint8),       # label mask
                tf.TensorSpec(shape=(), dtype=tf.string),              # patient_id
                tf.TensorSpec(shape=(), dtype=tf.int32)                # slice_num
            )
        )

        # Cache BEFORE shuffle/batch: the generator (file loading, mask
        # combination) runs during the first pass only; later epochs read
        # the cached samples. Because the cache holds unbatched, unshuffled
        # samples, the shuffle below still produces a new order, and hence
        # new batch compositions, every epoch.
        dataset = dataset.cache()

        # Shuffle if training (acts on the cached samples every epoch)
        if shuffle and split == 'train':
            dataset = dataset.shuffle(
                buffer_size=len(all_patient_slices),
                reshuffle_each_iteration=True  # new random order each epoch
            )

        # Batch and prefetch
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
###################### Testing & Validation Functions ######################
|
| 746 |
+
|
| 747 |
+
def test_data_loading():
    """
    Smoke-test the data pipeline end to end

    Runs three steps and stops at the first failure:
        1. Create (and save) the patient-stratified splits and verify them
        2. Load one slice of the first training patient of fold 0
        3. Build the fold-0 training dataset and fetch one batch

    Returns:
        True if all steps succeed, False otherwise

    Raises:
        ValueError: If the dataset of the selected patient cannot be
            determined from its ID
    """
    print("\n" + "="*60)
    print("TESTING DATA LOADING")
    print("="*60)

    config = DataConfig()

    # Test 1: Create fold assignments
    print("\n[TEST 1] Creating patient stratified splits...")
    splitter = PatientStratifiedSplitter(config)
    fold_assignments = splitter.create_patient_stratified_splits(save=True)

    # Verify patient separation
    if not splitter.verify_patient_separation(fold_assignments):
        print("❌ Patient separation verification failed!")
        return False

    # Test 2: Load a single slice
    print("\n[TEST 2] Loading single slice...")
    loader = P1DataLoader(config)

    # Get a test patient from fold 0 train set
    test_patient = fold_assignments['folds']['fold_0']['train_patients'][0]

    # Determine which dataset this patient belongs to
    if not test_patient.startswith('1'):
        raise ValueError(
            f"Cannot determine dataset for patient {test_patient!r}: "
            f"expected an ID starting with '1'"
        )
    test_dataset = 'Local_SAI_GM_sp'
    # Middle of the configured slice range (10 for the default (1, 20)).
    slice_min, slice_max = config.datasets[test_dataset]['slice_range']
    test_slice = (slice_min + slice_max) // 2

    try:
        paired_input, combined_mask = loader.load_single_slice(
            test_patient, test_slice, test_dataset,
            'standard', 'binary'
        )

        print(f"✅ Loaded slice {test_patient}_{test_slice}")
        print(f"   Paired input shape: {paired_input.shape}")
        print(f"   Combined mask shape: {combined_mask.shape}")
        print(f"   Mask unique values: {np.unique(combined_mask)}")

    except Exception as e:
        print(f"❌ Failed to load slice: {e}")
        return False

    # Test 3: Create TensorFlow dataset
    print("\n[TEST 3] Creating TensorFlow dataset...")
    try:
        dataset = loader.create_dataset_for_fold(
            fold_id=0,
            split='train',
            preprocessing='standard',
            class_scenario='binary',
            batch_size=2,
            shuffle=True
        )

        # Get first batch
        received_batch = False
        for batch_paired, batch_masks, _, _ in dataset.take(1):
            received_batch = True
            print("✅ Created dataset")
            print(f"   Batch paired input shape: {batch_paired.shape}")
            print(f"   Batch masks shape: {batch_masks.shape}")
            print(f"   Paired input dtype: {batch_paired.dtype}")
            print(f"   Masks dtype: {batch_masks.dtype}")

        # A dataset whose every sample failed to load yields nothing; that
        # must not be reported as success.
        if not received_batch:
            print("❌ Failed to create dataset: dataset yielded no batches")
            return False

    except Exception as e:
        print(f"❌ Failed to create dataset: {e}")
        return False

    print("\n" + "="*60)
    print("✅ ALL TESTS PASSED")
    print("="*60)

    return True
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
###################### Main Execution ######################
|
| 829 |
+
|
| 830 |
+
if __name__ == "__main__":
    # Run the smoke tests and report the outcome.
    success = test_data_loading()

    if success:
        # Report the real location of the fold-assignment file instead of
        # a hard-coded (and outdated) name.
        splits_file = DataConfig().splits_file

        print("\n" + "="*60)
        print("DATA LOADER READY FOR USE")
        print("="*60)
        print("\nNext steps:")
        print(f"1. Verify {splits_file.name} created in {splits_file.parent}/")
        print("2. Check that all file paths are correct for your system")
        print("3. Proceed to model implementation")
    else:
        print("\n" + "="*60)
        print("❌ DATA LOADER TESTS FAILED")
        print("="*60)
        print("\nPlease fix the issues above before proceeding")
        # Signal the failure to the calling shell / CI.
        raise SystemExit(1)
|
| 847 |
+
|
models/for_GM/model_training_scripts/p1_pix2pix_var5.py
ADDED
|
@@ -0,0 +1,1313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P1 Article - Specialized Gray Matter (GM) Segmentation with U-Net Models - Journal Paper Implementation
|
| 3 |
+
|
| 4 |
+
Features:
|
| 5 |
+
- Multi-channel Generator output (softmax)
|
| 6 |
+
- Attention-Weighted PatchGAN Discriminator
|
| 7 |
+
- Adaptive hybrid loss (Weighted Categorical Cross-Entropy & Focal Dice)
|
| 8 |
+
- One-hot encoded targets
|
| 9 |
+
- Class weight computation per fold
|
| 10 |
+
- Optimized for severe class imbalance
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import tensorflow as tf
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
import numpy as np
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import json
|
| 21 |
+
|
| 22 |
+
from unet_model import build_unet_3class
|
| 23 |
+
|
| 24 |
+
# Import data loader
|
| 25 |
+
from p1_data_loader import DataConfig, P1DataLoader
|
| 26 |
+
|
| 27 |
+
# Import utilities from baseline
|
| 28 |
+
from utility_functions import (
|
| 29 |
+
clear_gpu_memory,
|
| 30 |
+
get_gpu_memory_info,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Import class weights utility
|
| 34 |
+
from p1_compute_class_weights import compute_and_save_class_weights, load_class_weights
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
print("TensorFlow Version:", tf.__version__)

###################### GPU Configuration ######################

# Ask TensorFlow to allocate GPU memory on demand for every visible GPU.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected - training will be slow")
else:
    try:
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Report the failure and carry on with TensorFlow's default allocation.
        print(f"GPU configuration error: {err}")
    else:
        print("✅ GPU memory growth enabled")
        print(f" Available GPUs: {len(physical_devices)}")

# GPU Memory Management for Sequential Experiments
# To properly release memory between experiments
|
| 58 |
+
|
| 59 |
+
###################### Target Preparation ######################
|
| 60 |
+
|
| 61 |
+
def prepare_inputs(paired_input, target_mask, num_classes):
    """
    Split the paired input and one-hot encode the target mask.

    Args:
        paired_input: Tensor documented by the author as (bs, 256, 512, 1),
            FLAIR and mask side by side. Only the first 256 entries along
            axis 2 are kept.
        target_mask: Class labels in [0, num_classes-1], documented as
            (bs, 256, 256). Dtypes that tf.one_hot does not accept as
            indices (e.g. float32) are cast to int32 first.
        num_classes: Depth of the one-hot encoding.

    Returns:
        flair_normalized: paired_input[:, :, :256, :], returned unchanged
            (no normalization is applied here; the author's note says it was
            normalized upstream).
        target_onehot: float32 one-hot mask, target_mask shape + (num_classes,).
    """
    # Keep only the FLAIR half of the side-by-side pair.
    flair_normalized = paired_input[:, :, :256, :]

    # tf.one_hot only accepts integer index dtypes; a mask delivered as float
    # would otherwise raise, so normalize the dtype before encoding.
    target_mask = tf.convert_to_tensor(target_mask)
    if target_mask.dtype not in (tf.uint8, tf.int32, tf.int64):
        target_mask = tf.cast(target_mask, tf.int32)

    target_onehot = tf.one_hot(target_mask, depth=num_classes, dtype=tf.float32)

    return flair_normalized, target_onehot
|
| 81 |
+
|
| 82 |
+
###################### Metrics Calculation ######################
|
| 83 |
+
|
| 84 |
+
def compute_classwise_metrics(all_val_true, all_val_pred, num_classes, exclude_class=None):
    """
    Compute class-wise Dice, Precision, and Recall for validation predictions.

    Predictions are hardened with argmax before counting, so the metrics
    describe the final label map rather than the soft probabilities.

    Args:
        all_val_true: List of one-hot encoded ground truth tensors,
            each (n, H, W, num_classes)
        all_val_pred: List of per-class score tensors from the generator,
            same shapes as all_val_true
        num_classes: Number of classes (size of the last axis)
        exclude_class: Class index to leave out of the metrics
            (e.g., 0 for background); None evaluates every class

    Returns:
        Dictionary {'dice': {...}, 'precision': {...}, 'recall': {...}}.
        Each inner dict maps 'class_<idx>' to a float and 'mean' to the
        average over the evaluated classes (NaN if no class was evaluated).
    """
    # Concatenate all batches and flatten to (N*H*W, num_classes)
    y_true_flat = tf.reshape(tf.concat(all_val_true, axis=0), [-1, num_classes])
    y_pred_flat = tf.reshape(tf.concat(all_val_pred, axis=0), [-1, num_classes])

    # Harden predictions: one-hot of the arg-max class per pixel
    y_pred_onehot = tf.one_hot(tf.argmax(y_pred_flat, axis=-1), depth=num_classes)

    # NumPy makes the boolean counting below straightforward
    y_true_np = y_true_flat.numpy()
    y_pred_np = y_pred_onehot.numpy()

    eps = 1e-7  # keeps the ratios finite when a class never occurs
    metrics = {'dice': {}, 'precision': {}, 'recall': {}}

    for class_idx in range(num_classes):
        if class_idx == exclude_class:
            continue

        true_class = y_true_np[:, class_idx]
        pred_class = y_pred_np[:, class_idx]

        # Evaluate each comparison once and reuse it for the three counts
        true_pos_px = true_class == 1
        pred_pos_px = pred_class == 1

        TP = np.sum(true_pos_px & pred_pos_px)
        FP = np.sum((true_class == 0) & pred_pos_px)
        FN = np.sum(true_pos_px & (pred_class == 0))

        key = f'class_{class_idx}'
        # Dice: 2*TP / (2*TP + FP + FN)
        metrics['dice'][key] = float((2 * TP) / (2 * TP + FP + FN + eps))
        # Precision: TP / (TP + FP)
        metrics['precision'][key] = float(TP / (TP + FP + eps))
        # Recall (sensitivity): TP / (TP + FN)
        metrics['recall'][key] = float(TP / (TP + FN + eps))

    # Mean over the evaluated classes. Stored as a plain float like the
    # per-class values; an empty class set yields NaN without the NumPy
    # "mean of empty slice" warning.
    for per_class in metrics.values():
        values = list(per_class.values())
        per_class['mean'] = float(np.mean(values)) if values else float('nan')

    return metrics
|
| 150 |
+
|
| 151 |
+
###################### Experiment Configuration ######################
|
| 152 |
+
|
| 153 |
+
class ExperimentConfig:
    """
    Configuration for multi-class pix2pix experiment.

    Instantiating the class has side effects: it creates the result, model,
    figure, log, checkpoint and class-weight directories (relative to the
    current working directory) and writes ``config.json`` into the
    checkpoint directory.
    """

    def __init__(self,
                 variant: int = 1,
                 preprocessing: str = 'standard',
                 class_scenario: str = 'binary',
                 fold_id: int = 0):
        """
        Args:
            variant: Experiment variant number (used in names and paths)
            preprocessing: Preprocessing label, 'standard' or 'zoomed'
            class_scenario: Class scenario label, 'binary'
            fold_id: Cross-validation fold index
        """
        # Experiment identification
        self.variant = variant
        self.preprocessing = preprocessing  # 'standard' or 'zoomed'
        self.class_scenario = class_scenario  # 'binary'
        self.fold_id = fold_id

        # Experiment name
        self.exp_name = f"exp_{variant}_multiclass_{preprocessing}_{class_scenario}_fold{fold_id}"

        # Number of classes: every scenario currently maps to 2 classes
        # (the former conditional returned 2 on both branches).
        self.num_classes = 2

        # Training hyperparameters
        self.batch_size = 4
        self.img_width = 256
        self.img_height = 256
        self.epochs = 20

        # Loss weights
        self.lambda_seg = 50  # seg loss weight
        self.lambda_gan = 1   # GAN loss weight

        # Adaptive loss parameters
        self.focal_gamma = 0.5       # Focal loss focusing parameter
        # Fraction of training at which the loss transition is centred
        # (0.25 * 20 epochs = epoch 5) — presumably fed to compute_beta_schedule.
        self.beta_threshold = 0.25
        self.beta_smoothness = 0.05  # Transition width
        self.use_focal_alpha = True  # Use class weights in focal loss

        # Optimizer parameters
        self.learning_rate = 2e-4
        self.beta_1 = 0.9

        # Attention parameters
        self.attention_weight = 2.0  # How much to upweight lesion regions

        # Paths
        scenario_tag = f"{preprocessing}_{class_scenario}"
        fold_tag = f"fold_{fold_id}"
        self.results_dir = Path(f"results_fold_{fold_id}_var_{variant}_bet_zscore_gm")
        self.models_dir = self.results_dir / "models" / scenario_tag
        self.figures_dir = self.results_dir / "figures" / scenario_tag / fold_tag
        self.logs_dir = self.results_dir / "logs" / scenario_tag / fold_tag

        # Checkpoint configuration
        self.checkpoint_dir = self.models_dir / fold_tag

        # Class weights directory
        self.weights_dir = Path("class_weights_gm")

        # Create directories
        for directory in (self.models_dir, self.figures_dir, self.logs_dir,
                          self.checkpoint_dir, self.weights_dir):
            directory.mkdir(parents=True, exist_ok=True)

        # Save configuration
        self.save_config()

    def save_config(self):
        """Save experiment configuration to <checkpoint_dir>/config.json"""
        config_dict = {
            'variant': self.variant,
            'variant_name': 'Multiclass_AttentionD_AdaptiveLoss',
            'preprocessing': self.preprocessing,
            'class_scenario': self.class_scenario,
            'fold_id': self.fold_id,
            'num_classes': self.num_classes,
            'batch_size': self.batch_size,
            'epochs': self.epochs,
            'lambda_seg': self.lambda_seg,
            'lambda_gan': self.lambda_gan,
            'focal_gamma': self.focal_gamma,
            'beta_threshold': self.beta_threshold,
            'beta_smoothness': self.beta_smoothness,
            'learning_rate': self.learning_rate,
            'beta_1': self.beta_1,
            'attention_weight': self.attention_weight,
            'innovation': 'Phase-transitioning segmentation loss (Weighted CE → Focal Loss)',
            # Settings that were previously not persisted with the run.
            'exp_name': self.exp_name,
            'img_width': self.img_width,
            'img_height': self.img_height,
            'use_focal_alpha': self.use_focal_alpha,
        }

        config_file = self.checkpoint_dir / "config.json"
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
###################### Model Architecture ######################
|
| 247 |
+
|
| 248 |
+
def downsample(filters, size, apply_norm=True, use_groupnorm=True):
    """
    Downsample block for encoder

    Conv2D (stride 2, 'same' padding, no bias) -> optional normalization
    -> LeakyReLU, packaged as a tf.keras.Sequential.

    Args:
        filters: Number of filters
        size: Kernel size
        apply_norm: Whether to apply normalization
        use_groupnorm: If True, use GroupNorm (better for batch_size=1)
                       If False, use BatchNorm (original pix2pix)

    Returns:
        tf.keras.Sequential implementing the block
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(
            filters, size, strides=2, padding='same',
            kernel_initializer=initializer, use_bias=False
        )
    )

    if apply_norm:
        if use_groupnorm:
            # GroupNorm statistics do not depend on the batch size.
            # Aim for filters // 8 groups, capped at 32 ...
            groups = min(32, max(1, filters // 8))
            # ... then step down to the nearest divisor of `filters`, because
            # GroupNormalization requires channels % groups == 0. For the
            # usual 64/128/256/512 widths this leaves the count unchanged.
            while filters % groups:
                groups -= 1
            result.add(tf.keras.layers.GroupNormalization(groups=groups))
        else:
            # Original BatchNorm (can cause NaN with batch_size=1 at inference)
            result.add(tf.keras.layers.BatchNormalization(momentum=0.99))

    result.add(tf.keras.layers.LeakyReLU())

    return result
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def build_attention_discriminator(num_classes: int, input_channels: int = 1,
                                  attention_weight: float = 2.0, use_groupnorm: bool = True):
    """
    Assemble the attention-weighted PatchGAN discriminator.

    The network scores (image, mask) pairs on a 30x30 patch grid. Each patch
    score is then multiplied by a spatial weight derived from the mask:
    1.0 where the mask's arg-max class is 0, `attention_weight` elsewhere.

    Args:
        num_classes: Channel count of the one-hot / softmax mask input
        input_channels: Channel count of the image input
        attention_weight: Weight given to non-background pixels
        use_groupnorm: Select GroupNorm (True) or BatchNorm (False)

    Returns:
        tf.keras.Model named 'AttentionDiscriminator' taking
        [input_image, target_mask] and returning the weighted patch scores
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)

    image_in = tf.keras.layers.Input(shape=[256, 256, input_channels], name='input_image')
    mask_in = tf.keras.layers.Input(shape=[256, 256, num_classes], name='target_mask')

    # --- Spatial attention map from the mask: (bs, 256, 256, 1) ---
    labels = tf.argmax(mask_in, axis=-1, output_type=tf.int32)  # (bs, 256, 256)
    is_background = labels == 0
    background_weight = tf.ones_like(labels, dtype=tf.float32)
    foreground_weight = tf.ones_like(labels, dtype=tf.float32) * attention_weight
    attention_map = tf.expand_dims(
        tf.where(is_background, background_weight, foreground_weight),
        axis=-1,
    )

    # --- PatchGAN trunk on the channel-wise concatenation ---
    features = tf.keras.layers.concatenate([image_in, mask_in])  # (bs, 256, 256, 1+num_classes)
    # Three stride-2 stages: 256 -> 128 -> 64 -> 32; the first has no norm.
    for n_filters, with_norm in ((64, False), (128, True), (256, True)):
        features = downsample(
            n_filters, 4, apply_norm=with_norm, use_groupnorm=use_groupnorm
        )(features)

    padded = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 34, 34, 256)
    features = tf.keras.layers.Conv2D(
        512, 4, strides=1,
        kernel_initializer=kernel_init,
        use_bias=False
    )(padded)  # (bs, 31, 31, 512)

    if use_groupnorm:
        norm_layer = tf.keras.layers.GroupNormalization(groups=8)
    else:
        norm_layer = tf.keras.layers.BatchNormalization()
    features = tf.keras.layers.LeakyReLU()(norm_layer(features))

    padded = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 33, 33, 512)

    patch_scores = tf.keras.layers.Conv2D(
        1, 4, strides=1,
        kernel_initializer=kernel_init,
        name='patch_predictions'
    )(padded)  # (bs, 30, 30, 1)

    # --- Bring the attention map down to the patch grid ---
    # Stride-8 average pooling with 'same' padding gives a 32x32 map ...
    pooled_attention = tf.keras.layers.AveragePooling2D(
        pool_size=(9, 9), strides=(8, 8), padding='same'
    )(attention_map)

    # ... which is bilinearly resized to the exact patch-score resolution.
    patch_hw = [tf.shape(patch_scores)[1], tf.shape(patch_scores)[2]]
    patch_attention = tf.image.resize(pooled_attention, patch_hw, method='bilinear')

    return tf.keras.Model(
        inputs=[image_in, mask_in],
        outputs=patch_scores * patch_attention,
        name='AttentionDiscriminator'
    )
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
###################### Beta Scheduling ######################
|
| 377 |
+
|
| 378 |
+
def smooth_step(x, threshold=0.5, smoothness=0.1):
    """
    Smooth step function for phase transition

    Creates smooth transition around threshold value using sigmoid:
        sigmoid((x - threshold) / smoothness)

    Args:
        x: Current progress (typically epoch / total_epochs)
        threshold: Center point of transition (e.g., 0.5 for epoch 25/50)
        smoothness: Width of transition (smaller = sharper, larger = smoother).
            Must be positive.

    Returns:
        Value in [0, 1] representing transition progress
        - x << threshold: returns ≈ 0
        - x ≈ threshold: returns ≈ 0.5
        - x >> threshold: returns ≈ 1

    Raises:
        ValueError: If smoothness is a plain Python number that is <= 0.

    Example:
        epoch_progress = 0.3  # Epoch 15/50
        beta = smooth_step(0.3, threshold=0.5, smoothness=0.1)
        # beta = sigmoid(-2) ≈ 0.12 (mostly phase 1)

        epoch_progress = 0.5  # Epoch 25/50
        beta = smooth_step(0.5, threshold=0.5, smoothness=0.1)
        # beta = sigmoid(0) = 0.5 (equal mix)

        epoch_progress = 0.7  # Epoch 35/50
        beta = smooth_step(0.7, threshold=0.5, smoothness=0.1)
        # beta = sigmoid(2) ≈ 0.88 (mostly phase 2)
    """
    # A zero width divides by zero (NaN at the threshold) and a negative one
    # silently inverts the schedule. Only plain Python numbers can be checked
    # eagerly; tensor-valued widths are passed through untouched.
    if isinstance(smoothness, (int, float)) and smoothness <= 0:
        raise ValueError(f"smoothness must be positive, got {smoothness}")

    # Sigmoid centered at threshold; (x - threshold) / smoothness sets steepness
    return tf.sigmoid((x - threshold) / smoothness)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def compute_beta_schedule(current_epoch, total_epochs,
                          threshold=0.5, smoothness=0.1):
    """
    Map the current epoch to a beta blending factor.

    Training progress (current_epoch / total_epochs, evaluated in float32)
    is passed through smooth_step.

    Args:
        current_epoch: Epoch counter (0-indexed)
        total_epochs: Number of epochs in the run
        threshold: Progress fraction at which beta reaches 0.5
        smoothness: Width of the transition around the threshold

    Returns:
        Beta value in [0, 1]
    """
    elapsed = tf.cast(current_epoch, tf.float32)
    horizon = tf.cast(total_epochs, tf.float32)
    return smooth_step(elapsed / horizon, threshold, smoothness)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
###################### Loss Functions ######################
|
| 433 |
+
|
| 434 |
+
def unified_focal_loss(y_true, y_pred, gamma=2.0, alpha=None, exclude_class=None):
    """
    Focal cross-entropy over one-hot targets.

    Per pixel: FL = alpha_c * (1 - p_t)^gamma * (-log p_t), where p_t is the
    probability assigned to the true class c. The modulating factor shrinks
    the contribution of pixels that are already classified confidently.

    Args:
        y_true: One-hot ground truth (bs, H, W, num_classes)
        y_pred: Class probabilities (bs, H, W, num_classes)
        gamma: Focusing exponent; 0 reduces the loss to cross-entropy,
            larger values emphasize hard pixels
        alpha: Optional per-class weights (num_classes,); None disables
            class weighting
        exclude_class: Optional class index whose pixels are masked out

    Returns:
        Scalar loss. Without exclude_class this is the mean over all pixels;
        with it, the sum over kept pixels divided by their count (+1e-7).
    """
    # Keep probabilities away from 0 so the logarithm stays finite.
    probs = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # Probability of the true class per pixel: (bs, H, W)
    p_true = tf.reduce_sum(y_true * probs, axis=-1)

    modulating = tf.pow(1.0 - p_true, gamma)
    neg_log_likelihood = -tf.math.log(p_true)
    per_pixel = modulating * neg_log_likelihood

    if alpha is not None:
        # Look up the weight of each pixel's true class via the one-hot mask.
        class_weights = tf.reshape(tf.cast(alpha, dtype=tf.float32), [1, 1, 1, -1])
        per_pixel = tf.reduce_sum(y_true * class_weights, axis=-1) * per_pixel

    if exclude_class is None:
        return tf.reduce_mean(per_pixel)

    # 1.0 for pixels that count, 0.0 for pixels of the excluded class.
    keep = tf.cast(tf.argmax(y_true, axis=-1) != exclude_class, tf.float32)
    return tf.reduce_sum(per_pixel * keep) / (tf.reduce_sum(keep) + 1e-7)
|
| 509 |
+
|
| 510 |
+
def unified_focal_dice_loss(y_true, y_pred, gamma=0.5, delta=0.6, alpha=None, exclude_class=None):
    """
    Unified Focal Loss - Dice-based

    Per class, computed over the whole batch with soft counts:
        loss_c = (1 - Dice_c)^gamma * (1 - precision_c * recall_c)^delta
    optionally scaled by alpha[c]; the result is the mean over the
    evaluated classes.

    Args:
        y_true: Ground truth one-hot (bs, H, W, num_classes)
        y_pred: Predicted probabilities (bs, H, W, num_classes); the last
            dimension must be statically known
        gamma: Exponent on the (1 - Dice) term (default 0.5)
            - gamma=0: the Dice term drops out
            - gamma>0: classes with low Dice contribute more
        delta: Exponent on the (1 - precision * recall) term (default 0.6)
        alpha: Per-class weights (num_classes,) - optional
        exclude_class: Class index to exclude from loss

    Returns:
        Scalar loss value; 0 if no class remains after exclusion

    Raises:
        ValueError: If the number of classes of y_pred is not statically known
    """
    smooth = 1e-6
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # The Python loop below needs a concrete class count, so read the static
    # shape (tf.shape would return a tensor, which cannot drive range()).
    num_classes = y_pred.shape[-1]
    if num_classes is None:
        raise ValueError(
            "unified_focal_dice_loss requires a statically known number of "
            "classes (last dimension of y_pred)"
        )

    unified_losses = []

    for class_idx in range(num_classes):
        # Skip excluded class
        if exclude_class is not None and class_idx == exclude_class:
            continue

        # Flatten this class's channel for the soft counts
        y_true_f = tf.reshape(y_true[..., class_idx], [-1])
        y_pred_f = tf.reshape(y_pred[..., class_idx], [-1])

        # Soft true positives, false positives, false negatives
        tp = tf.reduce_sum(y_true_f * y_pred_f)
        fp = tf.reduce_sum((1.0 - y_true_f) * y_pred_f)
        fn = tf.reduce_sum(y_true_f * (1.0 - y_pred_f))

        # Precision and recall
        precision = (tp + smooth) / (tp + fp + smooth)
        recall = (tp + smooth) / (tp + fn + smooth)

        # Dice coefficient
        dice = (2.0 * tp + smooth) / (2.0 * tp + fp + fn + smooth)

        unified_loss_class = tf.pow(1.0 - dice, gamma) * tf.pow(1.0 - precision * recall, delta)

        # Apply class weights
        if alpha is not None:
            unified_loss_class = unified_loss_class * tf.cast(alpha[class_idx], tf.float32)

        unified_losses.append(unified_loss_class)

    # Nothing left to average (e.g. a single class that is also excluded):
    # return a zero loss instead of stacking an empty list.
    if not unified_losses:
        return tf.zeros([], dtype=y_pred.dtype)

    # Mean across the evaluated classes
    return tf.reduce_mean(tf.stack(unified_losses))
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def weighted_categorical_crossentropy(y_true, y_pred, class_weights, exclude_class=None):
    """
    Compute a class-weighted categorical cross-entropy over all pixels.

    Each pixel's cross-entropy is scaled by the weight of its ground-truth
    class. When ``exclude_class`` is given, pixels whose ground-truth argmax
    equals that class contribute nothing, and the summed loss is divided by
    the number of remaining pixels instead of the total pixel count.

    Args:
        y_true: One-hot ground truth; the last axis indexes classes. The
            weights are reshaped to (1, 1, 1, num_classes), so a rank-4
            (bs, H, W, num_classes) layout is expected.
        y_pred: Class probabilities with the same layout as ``y_true``.
        class_weights: One weight per class, shape (num_classes,).
        exclude_class: Optional int, class index to leave out of the loss.

    Returns:
        Scalar loss tensor.
    """
    eps = 1e-7

    # Keep probabilities strictly inside (0, 1) so the log stays finite.
    probs = tf.clip_by_value(y_pred, eps, 1.0 - eps)

    # Per-pixel cross-entropy: -sum_c y_true[c] * log(p[c]).
    per_pixel_ce = -tf.reduce_sum(y_true * tf.math.log(probs), axis=-1)

    # Broadcast the per-class weights over batch and spatial axes, then pick
    # the weight of each pixel's true class via the one-hot mask.
    broadcast_weights = tf.reshape(
        tf.cast(class_weights, dtype=tf.float32), [1, 1, 1, -1]
    )
    per_pixel_weight = tf.reduce_sum(y_true * broadcast_weights, axis=-1)

    if exclude_class is None:
        return tf.reduce_mean(per_pixel_ce * per_pixel_weight)

    # Zero out pixels belonging to the excluded class and normalise by the
    # number of pixels that remain (epsilon guards against an empty mask).
    true_labels = tf.argmax(y_true, axis=-1)
    keep_mask = tf.cast(true_labels != exclude_class, tf.float32)
    masked_ce = per_pixel_ce * per_pixel_weight * keep_mask
    return tf.reduce_sum(masked_ce) / (tf.reduce_sum(keep_mask) + 1e-7)
|
| 619 |
+
|
| 620 |
+
# Combined Adaptive Loss #
|
| 621 |
+
|
| 622 |
+
def adaptive_segmentation_loss(y_true, y_pred, class_weights, beta,
                               focal_gamma=0.5, use_focal_alpha=True,
                               exclude_class=None):
    """
    Blend weighted cross-entropy and the unified focal/Dice loss.

    The result is the convex combination
    ``(1 - beta) * weighted_ce + beta * focal``; ``beta`` is supplied by the
    caller, so the schedule that moves it from 0 to 1 lives outside this
    function.

    Args:
        y_true: Ground truth (bs, H, W, num_classes) one-hot
        y_pred: Predictions (bs, H, W, num_classes) softmax probabilities
        class_weights: Per-class weights (num_classes,)
        beta: Mixing parameter in [0, 1]
            - beta=0: pure weighted CE
            - beta=1: pure focal loss
            - beta=0.5: equal mix
        focal_gamma: Focusing parameter forwarded to the focal loss (default 0.5)
        use_focal_alpha: If true, ``class_weights`` are also passed to the
            focal loss as its ``alpha``; otherwise ``alpha`` is None
        exclude_class: Optional class index forwarded to both component losses

    Returns:
        seg_loss: Combined loss
        wcce_loss: Weighted CE component (for monitoring)
        focal_loss: Focal loss component (for monitoring)
    """
    # Component 1: class-weighted cross-entropy.
    ce_term = weighted_categorical_crossentropy(
        y_true, y_pred, class_weights, exclude_class=exclude_class
    )

    # Component 2: focal/Dice loss, optionally re-using the class weights.
    if use_focal_alpha:
        alpha_for_focal = class_weights
    else:
        alpha_for_focal = None
    focal_term = unified_focal_dice_loss(
        y_true, y_pred,
        gamma=focal_gamma,
        alpha=alpha_for_focal,
        exclude_class=exclude_class,
    )

    # Linear interpolation between the two components.
    blended = (1.0 - beta) * ce_term + beta * focal_term

    return blended, ce_term, focal_term
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# Binary cross-entropy for GAN loss
|
| 679 |
+
# Module-level loss object shared by generator_loss and discriminator_loss.
# from_logits=True makes the loss apply the sigmoid itself, so the
# discriminator outputs passed to it are treated as raw logits.
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def generator_loss(disc_generated_output, gen_output, target_onehot,
                   class_weights, beta, lambda_gan=1, lambda_seg=100,
                   focal_gamma=2.0, use_focal_alpha=True):
    """
    Generator loss: adversarial (GAN) loss + adaptive segmentation loss

    Args:
        disc_generated_output: Discriminator output for generated mask
        gen_output: Generated mask (bs, 256, 256, num_classes) softmax
        target_onehot: Target mask (bs, 256, 256, num_classes) one-hot
        class_weights: (num_classes,) weight per class
        beta: Phase transition parameter [0, 1] forwarded to
            adaptive_segmentation_loss
        lambda_gan: Weight for GAN loss (default 1)
        lambda_seg: Weight for segmentation loss (default 100)
        focal_gamma: Focal loss focusing parameter (default 2.0)
        use_focal_alpha: Whether to use class weights in focal loss

    Returns:
        total_gen_loss: lambda_gan * gan_loss + lambda_seg * seg_loss
        gan_loss: Adversarial component
        seg_loss: Combined segmentation loss
        wcce_loss: Weighted cross-entropy component (for monitoring)
        focal_loss: Focal component (for monitoring)
    """
    # GAN loss: the generator wants the discriminator to label its output
    # as real (all ones).
    gan_loss = bce_loss(
        tf.ones_like(disc_generated_output),
        disc_generated_output
    )

    # Segmentation loss. The caller's use_focal_alpha is forwarded; it was
    # previously hard-coded to True, which silently ignored the parameter.
    seg_loss, wcce_loss, focal_loss = adaptive_segmentation_loss(
        target_onehot, gen_output, class_weights, beta,
        focal_gamma=focal_gamma, use_focal_alpha=use_focal_alpha
    )

    # Total generator loss
    total_gen_loss = (lambda_gan * gan_loss) + (lambda_seg * seg_loss)

    return total_gen_loss, gan_loss, seg_loss, wcce_loss, focal_loss
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def discriminator_loss(disc_real_output, disc_generated_output):
    """
    Discriminator loss: score real pairs as ones and generated pairs as zeros.

    Args:
        disc_real_output: Discriminator output for the real mask
        disc_generated_output: Discriminator output for the generated mask

    Returns:
        Sum of the binary cross-entropy on the real and generated outputs
    """
    real_labels = tf.ones_like(disc_real_output)
    fake_labels = tf.zeros_like(disc_generated_output)

    loss_on_real = bce_loss(real_labels, disc_real_output)
    loss_on_fake = bce_loss(fake_labels, disc_generated_output)

    return loss_on_real + loss_on_fake
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
###################### Training Functions ######################
|
| 746 |
+
|
| 747 |
+
@tf.function
def train_step(input_image, target_onehot, generator, discriminator,
               generator_optimizer, discriminator_optimizer,
               class_weights_np, beta_value,
               lambda_gan, lambda_seg, focal_gamma, use_focal_alpha):
    """
    Run one optimisation step for both the generator and the discriminator.

    Args:
        input_image: Input FLAIR (bs, 256, 256, 1)
        target_onehot: Target mask (bs, 256, 256, num_classes) one-hot
        generator, discriminator: Models being trained
        generator_optimizer, discriminator_optimizer: Their optimizers
        class_weights_np: (num_classes,) weight per class
        beta_value: Current beta for phase transition
        lambda_gan, lambda_seg: Loss weights
        focal_gamma: Focal loss parameter
        use_focal_alpha: Whether to use class weights in focal

    Returns:
        Tuple (gen_total_loss, gen_gan_loss, gen_seg_loss, gen_wce_loss,
        gen_focal_loss, disc_loss, class_weights_np)
    """
    # One tape per network; each is used for exactly one gradient call.
    with tf.GradientTape() as g_tape:
        with tf.GradientTape() as d_tape:
            fake_mask = generator(input_image, training=True)

            # Discriminator scores for the (input, real) and (input, fake) pairs.
            real_score = discriminator(
                [input_image, target_onehot], training=True
            )
            fake_score = discriminator(
                [input_image, fake_mask], training=True
            )

            gen_losses = generator_loss(
                fake_score, fake_mask, target_onehot,
                class_weights_np, beta_value, lambda_gan, lambda_seg,
                focal_gamma, use_focal_alpha
            )
            (gen_total_loss, gen_gan_loss, gen_seg_loss,
             gen_wce_loss, gen_focal_loss) = gen_losses

            disc_loss = discriminator_loss(real_score, fake_score)

    # Both gradients are computed before either optimizer touches the weights.
    gen_grads = g_tape.gradient(gen_total_loss, generator.trainable_variables)
    disc_grads = d_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gen_grads, generator.trainable_variables)
    )
    discriminator_optimizer.apply_gradients(
        zip(disc_grads, discriminator.trainable_variables)
    )

    return (gen_total_loss, gen_gan_loss, gen_seg_loss, gen_wce_loss,
            gen_focal_loss, disc_loss, class_weights_np)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
def generate_and_save_images(generator, test_input, test_target,
                             epoch, save_path, num_classes):
    """
    Generate predictions and save one five-panel figure per batch item.

    Panels, left to right: input FLAIR, ground truth, predicted classes,
    maximum class probability, and a binary error map.

    Args:
        generator: Generator model
        test_input: Test input image (bs, 256, 512, 1); only the first 256
            columns are fed to the generator
        test_target: Test target mask (bs, 256, 256)
        epoch: Current epoch number (used in the file name)
        save_path: Directory (path object supporting ``/``) for the figures
        num_classes: Number of classes (sets the colour range of the masks)
    """
    last_class = num_classes - 1

    # Reading the static batch size avoids copying the whole batch to NumPy.
    for ik in range(test_input.shape[0]):
        # Extract FLAIR and restore the batch axis expected by the generator.
        # NOTE(review): the training loop passes batches through
        # prepare_inputs before calling the generator, while this slice is
        # used as-is — confirm the dataset already yields normalized FLAIR.
        flair_normalized = tf.expand_dims(test_input[ik, :, :256, :], axis=0)

        # Generate prediction and derive everything that will be plotted.
        prediction_softmax = generator(flair_normalized, training=False)
        pred_classes = tf.argmax(prediction_softmax, axis=-1).numpy()
        target_mask = test_target[ik].numpy()
        max_prob = tf.reduce_max(prediction_softmax[0], axis=-1).numpy()
        error_map = (pred_classes[0] != target_mask).astype(float)

        # (title, image, imshow keyword arguments, draw a colorbar?)
        panels = (
            ('Input FLAIR', flair_normalized[0, :, :, 0],
             {'cmap': 'gray'}, False),
            ('Ground Truth', target_mask,
             {'cmap': 'jet', 'vmin': 0, 'vmax': last_class}, True),
            ('Predicted Classes', pred_classes[0],
             {'cmap': 'jet', 'vmin': 0, 'vmax': last_class}, True),
            ('Max Probability', max_prob,
             {'cmap': 'viridis', 'vmin': 0, 'vmax': 1}, True),
            ('Error Map (Red=Wrong)', error_map,
             {'cmap': 'Reds', 'vmin': 0, 'vmax': 1}, True),
        )

        fig = plt.figure(figsize=(20, 5))
        try:
            for position, (title, image, imshow_kwargs, with_colorbar) in enumerate(panels, start=1):
                plt.subplot(1, 5, position)
                plt.title(title)
                plt.imshow(image, **imshow_kwargs)
                if with_colorbar:
                    plt.colorbar()
                plt.axis('off')

            plt.tight_layout()
            plt.savefig(save_path / f'epoch_{epoch:03d}_{ik+1}.png', dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails;
            # otherwise pyplot keeps it alive for the rest of the run.
            plt.close(fig)
|
| 879 |
+
|
| 880 |
+
|
| 881 |
+
###################### Main Training Function ######################
|
| 882 |
+
|
| 883 |
+
def _to_jsonable(value):
    """Recursively replace NumPy scalars/arrays inside containers with plain Python values."""
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


def train_experiment_with_metrics(config: ExperimentConfig):
    """
    Main training function for multi-class pix2pix with attention on discriminator and adaptive loss

    Steps: print the configuration, build the train/validation datasets,
    load (or compute) class weights, build generator and discriminator,
    restore the latest checkpoint and any ``best_dice_*.h5`` weights found in
    ``config.checkpoint_dir``, then train for ``config.epochs`` epochs. Every
    epoch runs validation, saves the best weights when the composite score
    improves (only once beta > 0.9), and writes visualization figures. The
    training history is written to ``history.json`` at the end. Model objects
    and GPU memory are released in a ``finally`` block whether or not
    training succeeds.

    Args:
        config: ExperimentConfig object

    Returns:
        (history, history_file): the history dict and the path of the JSON
        file it was written to.

    Raises:
        FileNotFoundError: if ``config.num_classes`` is not 2 (exception type
            kept for backward compatibility).
    """
    print("\n" + "="*70)
    print(f"TRAINING EXPERIMENT: {config.exp_name}")
    print("="*70)
    print(f"Variant: {config.variant} (Baseline + AttentionD + Adaptive Loss)")
    print(f"Preprocessing: {config.preprocessing}")
    print(f"Class scenario: {config.class_scenario} ({config.num_classes} classes)")
    print(f"Fold: {config.fold_id}")
    print(f"Epochs: {config.epochs}")
    print(f"Batch size: {config.batch_size}")
    print(f"Loss weights: λ_SEG={config.lambda_seg}, λ_GAN={config.lambda_gan}")
    print(f"Focal gamma: {config.focal_gamma}")
    print(f"Attention weight: {config.attention_weight}")
    print("="*70 + "\n")

    # Check initial GPU memory
    get_gpu_memory_info()

    # Initialize data loader
    data_config = DataConfig()
    data_loader = P1DataLoader(data_config)

    # Load datasets
    print("Loading training data...")
    train_dataset = data_loader.create_dataset_for_fold(
        fold_id=config.fold_id,
        split='train',
        preprocessing=config.preprocessing,
        class_scenario=config.class_scenario,
        batch_size=config.batch_size,
        shuffle=True
    )

    print("Loading validation data...")
    val_dataset = data_loader.create_dataset_for_fold(
        fold_id=config.fold_id,
        split='val',
        preprocessing=config.preprocessing,
        class_scenario=config.class_scenario,
        batch_size=config.batch_size,
        shuffle=False
    )

    # Get dataset sizes
    # Note: from_generator pipelines do not report a usable cardinality, so
    # the batch count is obtained by iterating once. This pass also warms the
    # in-memory cache so epoch 1 is fast.
    print("Warming dataset cache (first pass over data — subsequent epochs use RAM)...")
    train_size = sum(1 for _ in train_dataset)
    val_size = sum(1 for _ in val_dataset)
    # Do NOT rebuild the datasets here — that would create new generators and
    # throw away the cache we just populated.

    print(f"Training samples (batches): {train_size}")
    print(f"Validation samples (batches): {val_size}\n")

    # Load pre-computed class weights, computing them only if missing
    print("Computing class weights from training data...")
    try:
        class_weights = load_class_weights(
            config.fold_id, config.class_scenario,
            config.preprocessing, config.weights_dir
        )
        print("✅ Loaded pre-computed class weights")
    except FileNotFoundError:
        print("Computing class weights (this may take a few minutes)...")
        results = compute_and_save_class_weights(
            config.fold_id, config.class_scenario,
            config.preprocessing, str(config.weights_dir)
        )
        class_weights = np.array(results['class_weights'], dtype=np.float32)

    print(f"Class weights: {class_weights}")

    # Build models
    print("\n🏗️ Building models...")
    generator = build_unet_3class(input_shape=(256, 256, 1), num_classes=config.num_classes)
    discriminator = build_attention_discriminator(
        config.num_classes,
        input_channels=1,
        attention_weight=config.attention_weight,
        use_groupnorm=True
    )

    print(f"Generator parameters: {generator.count_params():,}")
    print(f"Discriminator parameters: {discriminator.count_params():,}\n")

    # Optimizers
    generator_optimizer = tf.keras.optimizers.legacy.Adam(
        config.learning_rate, beta_1=config.beta_1
    )
    discriminator_optimizer = tf.keras.optimizers.legacy.Adam(
        config.learning_rate, beta_1=config.beta_1
    )

    # Initialize optimizer variables
    # CRITICAL: Build optimizer variables by calling them once with dummy data
    # This prevents the "tf.function only supports singleton tf.Variables" error
    # NOTE(review): this warm-up applies a real gradient step to the freshly
    # initialized weights before any checkpoint is restored.
    print("Initializing optimizer variables...")
    dummy_input = tf.zeros((1, 256, 256, 1))
    dummy_target = tf.zeros((1, 256, 256, config.num_classes))

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(dummy_input, training=True)
        disc_output = discriminator([dummy_input, dummy_target], training=True)
        # Dummy losses
        dummy_gen_loss = tf.reduce_mean(gen_output)
        dummy_disc_loss = tf.reduce_mean(disc_output)

    # Apply dummy gradients to build optimizer variables
    gen_grads = gen_tape.gradient(dummy_gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(dummy_disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    print("✅ Optimizer variables initialized\n")

    # Checkpoint
    checkpoint = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator
    )

    manager = tf.train.CheckpointManager(
        checkpoint, config.checkpoint_dir, max_to_keep=1
    )

    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"✅ Restored from checkpoint: {manager.latest_checkpoint}\n")
    else:
        print("Starting training from scratch\n")

    # Load previously saved best weights, if present (these override
    # whatever the checkpoint restore put into the models)
    generator_weights_path = f"{config.checkpoint_dir}/best_dice_generator.h5"
    if os.path.isfile(generator_weights_path):
        generator.load_weights(generator_weights_path)

    discriminator_weights_path = f"{config.checkpoint_dir}/best_dice_discriminator.h5"
    if os.path.isfile(discriminator_weights_path):
        discriminator.load_weights(discriminator_weights_path)

    # Get example batch for visualization (the second validation batch)
    skip_n = 1
    example_paired, example_target, _, _ = next(iter(val_dataset.skip(skip_n).take(20)))

    print("Initializing metrics computer...")
    # Only the binary scenario (Background / Specialized_GM) is supported.
    if config.num_classes != 2:
        raise FileNotFoundError(
            f"Unsupported num_classes={config.num_classes}; only 2 classes are supported"
        )

    # Training history
    history = {
        'gen_total_loss': [],
        'gen_gan_loss': [],
        'gen_seg_loss': [],
        'gen_wce_loss': [],
        'gen_focal_loss': [],
        'disc_loss': [],
        'val_loss': [],
        'beta_value': [],
        'val_metrics': []
    }

    best_val_dice = 0.0
    exclude_class = None  # No class is excluded from the validation loss
    gm_class_key = f"class_{config.num_classes -1}"

    # Debug toggles kept at their current values: periodic checkpoints are
    # disabled and a visualization is written after every epoch.
    save_periodic_checkpoints = False
    visualize_every_epoch = True

    try:
        for epoch in range(config.epochs):
            start_time = time.time()

            # Compute beta for this epoch
            beta_value = compute_beta_schedule(
                epoch, config.epochs,
                config.beta_threshold, config.beta_smoothness
            )

            # Training metrics
            epoch_gen_total_loss = []
            epoch_gen_gan_loss = []
            epoch_gen_seg_loss = []
            epoch_gen_wce_loss = []
            epoch_gen_focal_loss = []
            epoch_disc_loss = []

            # Learning rate: the smaller of two linear schedules.
            # 1) linear in beta: falls to 1/8 of the base rate at beta=1
            # 2) linear in epoch: falls towards 0.5% of the base rate
            new_lr_1 = config.learning_rate * ((1-(7/8)*beta_value))
            new_lr_2 = config.learning_rate * ((1-(1-0.5e-2)*(epoch / config.epochs)))
            new_lr = min(new_lr_1, new_lr_2)
            generator_optimizer.learning_rate.assign(new_lr)
            discriminator_optimizer.learning_rate.assign(new_lr)

            # NOTE(review): lambda_GAN is only printed; train_step below still
            # receives config.lambda_gan. Confirm whether the beta-scaled
            # value was meant to be passed instead.
            lambda_GAN = config.lambda_gan*(1-beta_value.numpy()).astype(np.float64)
            print(f"\nEpoch {epoch+1}/{config.epochs} (β={beta_value.numpy():.4f}) (λ_GAN={lambda_GAN:.4f}) (lr={new_lr:.6f})")
            train_bar = tqdm(train_dataset, total=train_size, desc="Training")

            for paired_input, target_mask, _, _ in train_bar:
                # Prepare inputs: normalize FLAIR + one-hot encode target
                flair_normalized, target_onehot = prepare_inputs(
                    paired_input, target_mask, config.num_classes
                )

                # Train step
                gen_total, gen_gan, gen_seg, gen_wce, gen_focal, disc, _ = train_step(
                    flair_normalized, target_onehot,
                    generator, discriminator,
                    generator_optimizer, discriminator_optimizer,
                    class_weights, beta_value,
                    config.lambda_gan, config.lambda_seg,
                    config.focal_gamma, config.use_focal_alpha
                )

                epoch_gen_total_loss.append(gen_total.numpy())
                epoch_gen_gan_loss.append(gen_gan.numpy())
                epoch_gen_seg_loss.append(gen_seg.numpy())
                epoch_gen_wce_loss.append(gen_wce.numpy())
                epoch_gen_focal_loss.append(gen_focal.numpy())
                epoch_disc_loss.append(disc.numpy())

                # Update progress bar
                train_bar.set_postfix({
                    'G_loss': f"{gen_total.numpy():.4f}",
                    'D_loss': f"{disc.numpy():.4f}",
                    'SEG': f"{gen_seg.numpy():.4f}"
                })

            # Calculate epoch averages
            avg_gen_total = np.mean(epoch_gen_total_loss)
            avg_gen_gan = np.mean(epoch_gen_gan_loss)
            avg_gen_seg = np.mean(epoch_gen_seg_loss)
            avg_gen_wce = np.mean(epoch_gen_wce_loss)
            avg_gen_focal = np.mean(epoch_gen_focal_loss)
            avg_disc = np.mean(epoch_disc_loss)

            history['gen_total_loss'].append(avg_gen_total)
            history['gen_gan_loss'].append(avg_gen_gan)
            history['gen_seg_loss'].append(avg_gen_seg)
            history['gen_wce_loss'].append(avg_gen_wce)
            history['gen_focal_loss'].append(avg_gen_focal)
            history['disc_loss'].append(avg_disc)
            history['beta_value'].append(float(beta_value.numpy()))

            # Validation
            val_losses = []
            all_val_true = []
            all_val_pred = []

            for val_paired, val_target, _, _ in val_dataset:
                try:
                    val_flair_norm, val_target_onehot = prepare_inputs(
                        val_paired, val_target, config.num_classes
                    )

                    val_pred = generator(val_flair_norm, training=False)

                    val_seg_loss, _, _ = adaptive_segmentation_loss(
                        val_target_onehot, val_pred, class_weights,
                        beta_value, focal_gamma=config.focal_gamma, exclude_class=exclude_class
                    )

                    # Store true and prediction values for final metrics calculation
                    all_val_true.append(val_target_onehot)
                    all_val_pred.append(val_pred)

                    if not tf.math.is_nan(val_seg_loss):
                        val_losses.append(val_seg_loss.numpy())
                except Exception as err:
                    # Best effort: skip the failing batch, but say so, and let
                    # KeyboardInterrupt/SystemExit propagate.
                    print(f"⚠️ Skipping validation batch: {err!r}")
                    continue

            if len(val_losses) > 0:
                avg_val_loss = np.mean(val_losses)
                history['val_loss'].append(avg_val_loss)

                # Compute class-wise metrics
                val_metrics = compute_classwise_metrics(
                    all_val_true, all_val_pred,
                    config.num_classes
                )
                history['val_metrics'].append(val_metrics)

                # Print validation results
                epoch_time = time.time() - start_time
                print(f"\n{'='*70}")
                print(f"Epoch {epoch+1}/{config.epochs} Summary (Time: {epoch_time:.2f}s)")
                print(f"{'='*70}")
                print(f"Training Losses:")
                print(f" Generator Total: {avg_gen_total:.4f} | GAN: {avg_gen_gan:.4f} | SEG: {avg_gen_seg:.4f}")
                print(f" WCE: {avg_gen_wce:.4f} | Focal: {avg_gen_focal:.4f} | Discriminator: {avg_disc:.4f}")
                print(f"\nValidation Loss: {avg_val_loss:.4f}")
                for header, metric_key, label in (
                    ("Dice Scores", 'dice', "Dice"),
                    ("Precision", 'precision', "Precision"),
                    ("Recall", 'recall', "Recall"),
                ):
                    print(f"\nClass-wise {header}:")
                    for class_name, score in val_metrics[metric_key].items():
                        if class_name != 'mean':
                            print(f" {class_name}: {score:.4f}")
                    print(f" Mean {label}: {val_metrics[metric_key]['mean']:.4f}")
                print(f"{'='*70}\n")

                # Save best model based on a composite score: 90% Dice of the
                # last class plus 10% of (1 - 10 * validation loss). Only
                # considered once the focal phase dominates (beta > 0.9).
                gm_val_dice = val_metrics['dice'].get(gm_class_key)
                if gm_val_dice is None:
                    # Never reuse a stale value from an earlier epoch.
                    print(f"Warning: '{gm_class_key}' missing from Dice metrics; skipping best-model check")
                else:
                    overall_val_performance = 0.9 * gm_val_dice + 0.1 * (1-10*avg_val_loss)
                    if overall_val_performance > best_val_dice and beta_value.numpy() > 0.9:
                        best_val_dice = overall_val_performance
                        generator.save_weights(f"{config.checkpoint_dir}/best_dice_generator.h5")
                        discriminator.save_weights(f"{config.checkpoint_dir}/best_dice_discriminator.h5")
                        print(f"✓ Best model saved (performance: {best_val_dice:.4f})")
            else:
                print("Warning: No valid validation batches")
                # Define the value used by the summary below so it neither
                # crashes nor reports the previous epoch's loss.
                avg_val_loss = float('nan')
                history['val_loss'].append(float('nan'))
                history['val_metrics'].append({})

            # Print epoch summary
            epoch_time = time.time() - start_time
            print(f"Epoch {epoch+1} Summary:")
            print(f" Gen Total Loss: {avg_gen_total:.4f}")
            print(f" Gen GAN Loss: {avg_gen_gan:.4f}")
            print(f" Gen Seg Loss: {avg_gen_seg:.4f}")
            print(f" - WCE component: {avg_gen_wce:.4f}")
            print(f" - Focal component: {avg_gen_focal:.4f}")
            print(f" Disc Loss: {avg_disc:.4f}")
            print(f" Val Loss: {avg_val_loss:.4f}")
            print(f" Beta: {beta_value.numpy():.4f}")
            print(f" Time: {epoch_time:.2f}s")

            # Save checkpoint (currently disabled, see flag above)
            if save_periodic_checkpoints and (epoch + 1) % 5 == 0:
                manager.save()
                print(f" 💾 Saved checkpoint")

            # Generate sample images
            if visualize_every_epoch or (epoch + 1) % 5 == 0 or epoch == 0:
                generate_and_save_images(
                    generator, example_paired, example_target,
                    epoch + 1, config.figures_dir, config.num_classes
                )
                print(f" 📊 Saved visualization")

        # Save history. Top-level numbers become floats; nested containers
        # (the per-epoch metric dicts) have NumPy values converted so that
        # json.dump cannot fail after the whole run has finished.
        history_serializable = {
            key: [float(val) if isinstance(val, (int, float, np.number)) else _to_jsonable(val)
                  for val in values]
            for key, values in history.items()
        }

        history_file = config.checkpoint_dir / "history.json"
        with open(history_file, 'w') as f:
            json.dump(history_serializable, f, indent=2)

        return history, history_file

    finally:
        # CRITICAL: Always cleanup, even if training fails
        print("\n🧹 Cleaning up resources...")

        # Delete models explicitly to break references
        try:
            del generator
            del discriminator
            del generator_optimizer
            del discriminator_optimizer
            del checkpoint
            del manager
            del train_dataset
            del val_dataset
            # class_weights don't need deletion (they're constants, not variables)
            print("✅ Deleted model objects")
        except Exception as e:
            print(f"⚠️ Error deleting objects: {e}")

        # Clear GPU memory
        clear_gpu_memory()

        # Check final GPU memory
        get_gpu_memory_info()
|
| 1296 |
+
|
| 1297 |
+
|
| 1298 |
+
###################### Main Execution ######################
|
| 1299 |
+
|
| 1300 |
+
if __name__ == "__main__":
    # Example run: binary class scenario, standard preprocessing, fold 0.
    experiment_settings = dict(
        variant=1,
        preprocessing='standard',
        class_scenario='binary',
        fold_id=0,
    )
    config = ExperimentConfig(**experiment_settings)

    history, history_path = train_experiment_with_metrics(config)

    banner = "=" * 70
    print("\n" + banner)
    print("EXPERIMENT COMPLETE")
    print(banner)
|
models/for_GM/model_training_scripts/p1_predict_new_data_gm.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P1 Article - Prediction Script for New Data (No Ground Truth)
|
| 3 |
+
|
| 4 |
+
Predicts specialized Gray Matter segmentation masks for new HC/MS cohort patients.
|
| 5 |
+
|
| 6 |
+
Outputs per patient:
|
| 7 |
+
- {patient_id}_gm_mask.nii.gz → binary gm mask (class 1)
|
| 8 |
+
|
| 9 |
+
Developer:
|
| 10 |
+
Mahdi Bashiri Bawil
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import tensorflow as tf
|
| 14 |
+
import os
|
| 15 |
+
import numpy as np
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
import nibabel as nib
|
| 19 |
+
import argparse
|
| 20 |
+
|
| 21 |
+
print("TensorFlow Version:", tf.__version__)


###################### GPU Configuration ######################

# Enable on-demand GPU memory allocation for every visible GPU; report
# (rather than raise) if TensorFlow rejects the setting.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected – inference will run on CPU")
else:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"✅ GPU memory growth enabled ({len(physical_devices)} GPU(s) found)")
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
###################### Configuration ######################
|
| 39 |
+
|
| 40 |
+
class PredictConfig:
    """
    Settings container for the new-data prediction pipeline.

    Everything configurable arrives through the constructor (the CLI at the
    bottom of this file forwards its options here); the class list and the
    image size are fixed. Creating an instance prints a summary to stdout.
    """

    def __init__(
        self,
        # ── Model settings ──────────────────────────────────────────────────
        variant: int = 1,
        preprocessing: str = "standard",
        class_scenario: str = "binary",
        architecture_name: str = "unet",
        model_name: str = "best_dice_model.h5",
        fold_id: int = 0,

        # ── Slice range (1-based, inclusive) ────────────────────────────────
        # Slices outside [slice_start, slice_end] are not fed to the model
        # and keep empty (zero) masks.
        slice_start: int = 1,
        slice_end: int = 20,

        # ── Data root ───────────────────────────────────────────────────────
        data_root: str = "/mnt/d/TEMP_P4",

        # ── Post-processing ─────────────────────────────────────────────────
        apply_postprocess: bool = False,
        min_object_size: int = 5,
        closing_kernel_size: int = 2,
    ):
        # Experiment identity
        self.variant, self.fold_id = variant, fold_id
        self.preprocessing, self.class_scenario = preprocessing, class_scenario
        self.architecture_name, self.model_name = architecture_name, model_name

        # Fixed label set
        self.class_names = ["Background", "Specialized_GM"]
        self.num_classes = 2

        # Network input size (must match training)
        self.img_width = self.img_height = 256

        # Active slice window (1-based, inclusive)
        self.slice_start, self.slice_end = slice_start, slice_end

        # Morphological post-processing options
        self.apply_postprocess = apply_postprocess
        print(f'\n \t apply_postprocess: {apply_postprocess} \n')
        self.min_object_size = min_object_size
        self.closing_kernel_size = closing_kernel_size

        # Input data: one pre-processed FLAIR directory per cohort under data_root
        self.data_root = Path(data_root)
        self.cohorts = {
            cohort: self.data_root / f"{cohort}_COHORT_PREP_prepared" / "FLAIR_Preprocessed"
            for cohort in ("HC", "MS")
        }

        # Where trained weights are looked up (adjust if you use a single fold)
        self.results_dir = Path(f"results_fold_{fold_id}_var_{variant}_bet_zscore_gm")
        self.models_dir = self.results_dir / "models" / f"{preprocessing}_{class_scenario}"

        # ── Print summary ────────────────────────────────────────────────────
        rule = '=' * 70
        summary = [
            f"\n{rule}",
            "PREDICTION CONFIGURATION (New Data)",
            f"{rule}",
            f" Variant : {self.variant}",
            f" Fold : {self.fold_id}",
            f" Preprocessing : {self.preprocessing}",
            f" Class scenario : {self.class_scenario} ({self.num_classes} classes)",
            f" Architecture : {self.architecture_name}",
            f" Model file : {self.model_name}",
            f" Slice range : {self.slice_start} – {self.slice_end} (1-based)",
            f" Post-processing : {self.apply_postprocess}",
            f" Data root : {self.data_root}",
            f"{rule}\n",
        ]
        for line in summary:
            print(line)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
###################### Utility Helpers ######################
|
| 128 |
+
|
| 129 |
+
def load_nifti(path: Path):
    """
    Read a NIfTI file.

    Returns:
        (data, image): the voxel data as a float32 numpy array and the
        nibabel image object it was read from.
    """
    nifti_image = nib.load(str(path))
    voxel_data = nifti_image.get_fdata(dtype=np.float32)
    return voxel_data, nifti_image
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def save_binary_nifti(mask: np.ndarray, save_path: Path, reference_img):
    """
    Save a binary 3-D mask as a uint8 NIfTI file.

    Args:
        mask         : boolean/uint8 array. It is written with the reference
                       affine, so it is expected to live on the reference
                       image's voxel grid (assumption – not checked here).
        save_path    : destination path (*.nii.gz); missing parent
                       directories are created.
        reference_img: nibabel image whose affine/header are reused.
    """
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # When a header is supplied, nibabel takes the on-disk data type from
    # that header, not from the array. Reusing the reference header as-is
    # would therefore store the mask with the reference image's dtype.
    # Work on a copy so the caller's image is left untouched.
    header = reference_img.header.copy()
    header.set_data_dtype(np.uint8)

    nifti_out = nib.Nifti1Image(
        mask.astype(np.uint8),
        reference_img.affine,
        header,
    )
    nib.save(nifti_out, str(save_path))
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def preprocess_slice(slice_2d: np.ndarray, target_h: int = 256, target_w: int = 256) -> np.ndarray:
    """
    Turn one 2-D slice into a model-ready batch of shape (1, H, W, 1).

    The slice is resized to (target_h, target_w) with bilinear interpolation
    only when its shape differs; the result is always float32.

    No intensity normalisation is applied here: the input is assumed to be
    normalised already, in the same way as the training data
    (NOTE(review): confirm against the data-preparation step).

    Args:
        slice_2d: 2-D array (rows, cols).
        target_h: output height.
        target_w: output width.

    Returns:
        float32 array of shape (1, target_h, target_w, 1).

    Raises:
        ValueError : if `slice_2d` is not 2-D (from the shape unpacking).
        ImportError: if a resize is required and OpenCV is not installed.
    """
    h, w = slice_2d.shape
    if h != target_h or w != target_w:
        # Imported lazily: OpenCV is only needed when a resize actually
        # happens, so correctly sized inputs work without it. There is no
        # alternative resize backend.
        import cv2

        slice_2d = cv2.resize(
            slice_2d, (target_w, target_h), interpolation=cv2.INTER_LINEAR
        )

    # shape → (1, H, W, 1)
    return slice_2d[np.newaxis, :, :, np.newaxis].astype(np.float32)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def post_process_pred(pred_classes: np.ndarray, num_classes: int = 2,
                      min_object_size: int = 5, closing_kernel_size: int = 2) -> np.ndarray:
    """
    Morphological post-processing for a single 2-D prediction slice.

    Pipeline (applied to class 1 only):
        1. Extract the binary mask `pred_classes == 1`.
        2. binary_closing with a disk of radius `closing_kernel_size`
           – fill small holes / bridge tiny gaps.
        3. remove_small_objects with `min_size=min_object_size`
           – discard isolated noise specks.
        4. Rebuild an integer label map (0 = background, 1 = foreground).

    Args:
        pred_classes       : 2-D integer label map.
        num_classes        : accepted for interface compatibility; not used.
        min_object_size    : size threshold passed to remove_small_objects.
        closing_kernel_size: radius of the closing structuring element.

    Returns:
        Array with the same shape and dtype as `pred_classes` containing
        only 0 and 1. Labels other than 1 in the input become 0.
    """
    gm_mask = (pred_classes == 1)

    # Nothing to clean: skip the scikit-image import and kernel construction
    # for slices with no foreground.
    if not gm_mask.any():
        return np.zeros_like(pred_classes)

    from skimage.morphology import remove_small_objects, binary_closing, disk

    gm_mask = binary_closing(gm_mask, disk(closing_kernel_size))
    gm_mask = remove_small_objects(gm_mask, min_size=min_object_size)

    post_pred = np.zeros_like(pred_classes)
    post_pred[gm_mask] = 1

    return post_pred
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
###################### Model Loading ######################
|
| 208 |
+
|
| 209 |
+
def load_model(config: PredictConfig, fold_id: int):
    """
    Instantiate the configured architecture and load one fold's weights.

    The builder is imported on demand from the module that matches
    `config.architecture_name`; weights are read from
    `config.models_dir / fold_<fold_id> / config.model_name`.

    Returns:
        The generator (keras Model) with weights loaded.

    Raises:
        ValueError       : unknown architecture name.
        FileNotFoundError: the weights file does not exist.
    """
    arch = config.architecture_name
    if arch == "unet":
        from unet_model import build_unet_3class as build_fn
    elif arch == "attnunet":
        from attn_unet_model import build_attention_unet_3class as build_fn
    elif arch == "dlv3unet":
        from dlv3_unet_model_GN import build_deeplabv3_unet_3class as build_fn
    elif arch == "transunet":
        from trans_unet_model import build_trans_unet_3class as build_fn
    else:
        raise ValueError(f"Unknown architecture: {config.architecture_name}")

    # Fail before building the network if the weights are missing.
    model_path = config.models_dir / f"fold_{fold_id}" / config.model_name
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found: {model_path}")

    single_channel_shape = (config.img_height, config.img_width, 1)
    generator = build_fn(
        input_shape=single_channel_shape,
        num_classes=config.num_classes,
    )
    generator.load_weights(str(model_path))

    print(f" ✅ Fold {fold_id} model loaded from: {model_path}")
    return generator
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
###################### Per-Patient Prediction ######################
|
| 245 |
+
|
| 246 |
+
def predict_patient(
    patient_id: str,
    flair_path: Path,
    brain_mask_path: Path,
    models: list,  # list of keras generators (one per fold)
    config: PredictConfig,
    gm_out_dir: Path,
):
    """
    Run inference for a single patient and save the Specialized_GM mask.

    Steps:
        1. Load the FLAIR volume and the brain mask.
        2. Brain extraction: voxels outside the mask are set to the minimum
           intensity of the FLAIR volume.
        3. For each slice in [slice_start, slice_end] (1-based, clamped to
           the volume depth):
            a. Resize to (img_height, img_width) if needed.
            b. Run through all models and average their outputs.
            c. argmax → class label.
            d. Optional post-processing.
            e. Resize the label map back to the original slice size if needed.
        4. Slices outside the range keep an empty (zero) mask.
        5. Save `{patient_id}_gm_mask.nii.gz` (class 1 as a binary mask) in
           `gm_out_dir` and print the foreground voxel count.

    Args:
        patient_id     : identifier used in the output file name.
        flair_path     : path to the FLAIR NIfTI volume, indexed (H, W, S).
        brain_mask_path: path to the brain mask NIfTI volume.
        models         : non-empty list of callables returning a tensor with
                         `.numpy()` of shape (1, H, W, num_classes).
        config         : prediction settings.
        gm_out_dir     : output directory for the mask.

    Raises:
        ValueError: if `models` is empty (averaging would divide by zero and
                    silently produce empty masks).
    """
    if not models:
        raise ValueError("predict_patient requires at least one model")

    # ── Load data ────────────────────────────────────────────────────────────
    flair_data, flair_img = load_nifti(flair_path)   # (H, W, S)
    brain_mask, _ = load_nifti(brain_mask_path)      # (H, W, S)

    # Brain extraction: non-brain voxels get the volume minimum (not zero),
    # so the background stays at the lowest intensity present in the data.
    brain_mask_bool = brain_mask > 0
    flair_brain = np.copy(flair_data)
    flair_brain[~brain_mask_bool] = np.min(flair_data)

    H, W, num_slices = flair_brain.shape[0], flair_brain.shape[1], flair_brain.shape[2]

    # 1-based inclusive range → 0-based inclusive range, clamped to the volume.
    active_start = max(0, config.slice_start - 1)
    active_end = min(num_slices - 1, config.slice_end - 1)

    # Output volume has the same spatial shape as the input; slices that are
    # never visited stay zero.
    gm_volume = np.zeros((H, W, num_slices), dtype=np.uint8)

    # ── Inference loop (active slices only) ──────────────────────────────────
    for s in range(active_start, active_end + 1):
        model_input = preprocess_slice(               # (1, h, w, 1)
            flair_brain[:, :, s], config.img_height, config.img_width
        )

        # Ensemble: average the per-class outputs across all models
        softmax_sum = np.zeros(
            (1, config.img_height, config.img_width, config.num_classes),
            dtype=np.float32,
        )
        for gen in models:
            softmax_sum += gen(model_input, training=False).numpy()

        softmax_avg = softmax_sum / len(models)           # (1, h, w, C)
        pred_slice = np.argmax(softmax_avg, axis=-1)[0]   # (h, w)

        # Optional post-processing (done at model resolution)
        if config.apply_postprocess:
            pred_slice = post_process_pred(
                pred_slice,
                num_classes=config.num_classes,
                min_object_size=config.min_object_size,
                closing_kernel_size=config.closing_kernel_size,
            )

        # Map the label image back onto the original slice grid; nearest
        # neighbour keeps the labels discrete.
        if pred_slice.shape != (H, W):
            import cv2
            pred_slice = cv2.resize(
                pred_slice.astype(np.float32), (W, H),
                interpolation=cv2.INTER_NEAREST,
            ).astype(np.uint8)

        # Specialized_GM = class 1 in the 2-class setup
        gm_volume[:, :, s] = (pred_slice == 1).astype(np.uint8)

    # ── Save outputs ─────────────────────────────────────────────────────────
    gm_path = gm_out_dir / f"{patient_id}_gm_mask.nii.gz"
    save_binary_nifti(gm_volume, gm_path, flair_img)

    n_gm = int(gm_volume.sum())
    print(
        f" Patient {patient_id}: "
        f"GM voxels = {n_gm:6d}"
    )
    print(f" → {gm_path}")
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
###################### Main Prediction Pipeline ######################
|
| 355 |
+
|
| 356 |
+
def run_prediction(config: PredictConfig, fold_ids: list = None):
    """
    Full prediction pipeline for all patients in every configured cohort.

    For each cohort directory in `config.cohorts`, FLAIR volumes are read
    from `files/`, brain masks from `Brain_Masks/`, and the predicted masks
    are written to `GM_Masks/`. Patients without a brain mask, or whose
    prediction raises, are reported and counted as skipped; processing
    continues with the next patient.

    Args:
        config   : PredictConfig object.
        fold_ids : List of fold IDs to ensemble (e.g. [0, 1, 2, 3]).
                   If None, defaults to [0].

    Raises:
        ValueError: if `fold_ids` is an empty list.
    """
    if fold_ids is None:
        fold_ids = [0]
    if not fold_ids:
        # With no models the ensemble average is undefined.
        raise ValueError("fold_ids must contain at least one fold ID")

    # ── Load all fold models ─────────────────────────────────────────────────
    print(f"\nLoading models for folds: {fold_ids}")
    models = [load_model(config, fold_id) for fold_id in fold_ids]
    print(f"✅ {len(models)} model(s) loaded\n")

    nifti_suffix = ".nii.gz"

    # ── Iterate over cohorts ─────────────────────────────────────────────────
    for cohort_name, cohort_flair_dir in config.cohorts.items():
        files_dir = cohort_flair_dir / "files"
        brain_masks_dir = cohort_flair_dir / "Brain_Masks"
        gm_out_dir = cohort_flair_dir / "GM_Masks"

        # Discover patients from the files directory
        flair_files = sorted(files_dir.glob("*" + nifti_suffix))
        if not flair_files:
            print(f"⚠️ No FLAIR files found in {files_dir} – skipping {cohort_name} cohort")
            continue

        # Create the output directory only once we know there is something
        # to predict, so a wrong data root does not leave an empty tree behind.
        gm_out_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n{'='*70}")
        print(f"COHORT: {cohort_name} ({len(flair_files)} patients found)")
        print(f" FLAIR dir : {files_dir}")
        print(f" Brain masks dir : {brain_masks_dir}")
        print(f" Output GM dir : {gm_out_dir}")
        print(f"{'='*70}")

        skipped = 0
        for flair_path in tqdm(flair_files, desc=f"{cohort_name} patients"):
            # Patient ID = file name without the trailing ".nii.gz". Only the
            # suffix is removed; a ".nii" elsewhere in the name is kept.
            patient_id = flair_path.name[: -len(nifti_suffix)]

            brain_mask_path = brain_masks_dir / f"{patient_id}_brain_mask.nii.gz"

            if not brain_mask_path.exists():
                print(
                    f"\n ⚠️ Brain mask not found for patient {patient_id} "
                    f"(expected: {brain_mask_path}) – skipping"
                )
                skipped += 1
                continue

            # Best-effort per patient: one failure must not abort the cohort.
            try:
                predict_patient(
                    patient_id=patient_id,
                    flair_path=flair_path,
                    brain_mask_path=brain_mask_path,
                    models=models,
                    config=config,
                    gm_out_dir=gm_out_dir,
                )
            except Exception as exc:
                print(f"\n ❌ Error processing patient {patient_id}: {exc}")
                skipped += 1

        done = len(flair_files) - skipped
        print(
            f"\n ✅ {cohort_name} cohort done: {done} predicted, {skipped} skipped\n"
        )

    print("\n" + "="*70)
    print("ALL COHORTS PROCESSED")
    print("="*70)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
###################### Entry Point ######################
|
| 437 |
+
|
| 438 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="P4 – Predict Specialized_GM for new HC / MS cohort data"
    )
    parser.add_argument("--variant", type=int, default=1)
    parser.add_argument("--preprocessing", type=str, default="standard")
    parser.add_argument("--class_scenario", type=str, default="binary",
                        choices=["binary"])
    parser.add_argument("--architecture", type=str, default="unet",
                        choices=["unet", "attnunet", "dlv3unet", "transunet"])
    # NOTE(review): this default differs from PredictConfig's own default
    # ("best_dice_model.h5") – confirm which file name training writes.
    parser.add_argument("--model_name", type=str, default="best_dice_generator.h5")
    # NOTE(review): --folds is not forwarded as PredictConfig.fold_id, so the
    # results directory name always uses the config's default fold – confirm
    # this is intended when ensembling several folds.
    parser.add_argument("--folds", type=int, nargs="+", default=[0],
                        help="Fold IDs to ensemble (e.g. --folds 0 1 2 3)")
    parser.add_argument("--slice_start", type=int, default=1,
                        help="First slice to predict (1-based, inclusive)")
    parser.add_argument("--slice_end", type=int, default=20,
                        help="Last slice to predict (1-based, inclusive)")
    parser.add_argument("--data_root", type=str, default="/mnt/d/TEMP_P4")

    # Post-processing switch. Previously --no_postprocess used
    # action="store_false" and its value was negated again when building the
    # config, so passing the flag turned post-processing ON. Both flags now
    # write the same destination directly; the default stays "off", which is
    # what a run without either flag has always produced.
    postprocess_flags = parser.add_mutually_exclusive_group()
    postprocess_flags.add_argument("--postprocess", dest="apply_postprocess",
                                   action="store_true",
                                   help="Enable morphological post-processing")
    postprocess_flags.add_argument("--no_postprocess", dest="apply_postprocess",
                                   action="store_false",
                                   help="Disable morphological post-processing (default)")
    parser.set_defaults(apply_postprocess=False)

    parser.add_argument("--min_object_size", type=int, default=5)
    parser.add_argument("--closing_size", type=int, default=2)
    args = parser.parse_args()

    config = PredictConfig(
        variant=args.variant,
        preprocessing=args.preprocessing,
        class_scenario=args.class_scenario,
        architecture_name=args.architecture,
        model_name=args.model_name,
        slice_start=args.slice_start,
        slice_end=args.slice_end,
        data_root=args.data_root,
        apply_postprocess=args.apply_postprocess,
        min_object_size=args.min_object_size,
        closing_kernel_size=args.closing_size,
    )

    run_prediction(config, fold_ids=args.folds)
|
models/for_GM/model_training_scripts/unet_model.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###################### Libraries ######################
|
| 2 |
+
# Deep Learning
|
| 3 |
+
import keras
|
| 4 |
+
from keras.models import Model
|
| 5 |
+
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_unet_3class(input_shape=(256, 256, 1), num_classes=3):
    """
    Build a 4-level U-Net with dropout regularisation.

    Each level is two 3x3 ReLU convolutions. The encoder halves the
    resolution with max-pooling; the decoder doubles it with stride-2
    transposed convolutions and concatenates the matching encoder output.
    No batch normalisation is applied.

    Args:
        input_shape: shape of one input image (H, W, channels).
        num_classes: number of output channels. 1 gives a sigmoid output,
                     anything else a softmax over `num_classes` channels.

    Returns:
        keras Model named 'UNet'.
    """
    def conv_pair(tensor, filters):
        # Two consecutive 3x3 convolutions with the same filter count.
        tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return Conv2D(filters, 3, activation='relu', padding='same')(tensor)

    inputs = Input(input_shape)

    # Encoder: (filters, dropout rate applied after pooling) per level.
    # Layers are created in the same order as a hand-unrolled version, so
    # auto-generated layer names are unchanged.
    skips = []
    x = inputs
    for filters, drop_rate in ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2)):
        x = conv_pair(x, filters)
        skips.append(x)
        x = MaxPooling2D()(x)
        x = keras.layers.Dropout(drop_rate)(x)

    # Bottleneck
    x = conv_pair(x, 1024)
    x = keras.layers.Dropout(0.3)(x)

    # Decoder: (filters, dropout rate applied after the skip concatenation),
    # consuming the skip connections deepest-first.
    for filters, drop_rate in ((512, 0.2), (256, 0.2), (128, 0.1), (64, 0.1)):
        x = Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
        x = concatenate([x, skips.pop()])
        x = keras.layers.Dropout(drop_rate)(x)
        x = conv_pair(x, filters)

    # Output layer
    if num_classes == 1:
        outputs = Conv2D(1, 1, activation='sigmoid')(x)
    else:
        outputs = Conv2D(num_classes, 1, activation='softmax')(x)

    return Model(inputs, outputs, name='UNet')
|
models/for_GM/model_training_scripts/utility_functions.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P1 Article - Utility Functions
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Developer:
|
| 6 |
+
"Mahdi Bashiri Bawil"
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import gc
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
from tensorflow.keras import backend as K
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
print("TensorFlow Version:", tf.__version__)

###################### GPU Configuration ######################

# Enable on-demand GPU memory allocation for every visible GPU; report
# (rather than raise) if TensorFlow rejects the setting.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected - training will be slow")
else:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        print("✅ GPU memory growth enabled")
        print(f" Available GPUs: {len(physical_devices)}")
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
|
| 30 |
+
|
| 31 |
+
"""
|
| 32 |
+
GPU Memory Management for Sequential Experiments
|
| 33 |
+
To properly release memory between experiments
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def clear_gpu_memory():
    """
    Best-effort GPU memory cleanup between experiments.

    Call this after each experiment completes. Runs, in order: Keras session
    clear, three garbage-collection passes, default-graph reset, a reset of
    the GPU:0 memory statistics, and a memory-growth toggle on every GPU.
    Failures in the last two steps are tolerated so cleanup never aborts
    the caller; progress is printed to stdout.
    """
    print("\n" + "="*70)
    print("CLEANING UP GPU MEMORY")
    print("="*70)

    # Clear Keras session
    K.clear_session()
    print("✅ Cleared Keras session")

    # Force garbage collection multiple times
    for _ in range(3):
        gc.collect()
    print("✅ Ran garbage collection (3 passes)")

    # Reset TensorFlow graphs
    tf.compat.v1.reset_default_graph()
    print("✅ Reset default graph")

    # Reset the tracked memory statistics for the first GPU.
    try:
        tf.config.experimental.reset_memory_stats('GPU:0')
        print("✅ Reset GPU memory stats")
    except Exception:
        # Deliberately silent: this step is optional and is expected to fail
        # when no 'GPU:0' device exists. Only Exception is caught so that
        # KeyboardInterrupt / SystemExit still propagate.
        pass

    # Toggle memory growth off and on again for every GPU.
    # NOTE(review): TensorFlow documents that memory growth cannot be changed
    # after the device has been initialised; in that case this raises and
    # the warning below is printed – confirm that this step has any effect.
    try:
        physical_devices = tf.config.list_physical_devices('GPU')
        if physical_devices:
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, False)
                tf.config.experimental.set_memory_growth(device, True)
            print("✅ Reset memory growth (flushed allocator)")
    except Exception as e:
        print(f"⚠️ Could not reset memory growth: {e}")

    print("="*70 + "\n")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_gpu_memory_info():
    """
    Print current and peak memory usage for each visible GPU.

    One line is printed per GPU, with both values converted from bytes to
    MB. Any exception raised while querying is caught and printed, so this
    is safe to call purely for monitoring (e.g. to spot memory leaks
    between experiments).
    """
    bytes_per_mb = 1024**2
    try:
        # An empty device list simply skips the loop.
        for gpu in tf.config.list_physical_devices('GPU'):
            # Strip the physical-device prefix to get the short device
            # string ('GPU:0') passed to get_memory_info.
            short_name = gpu.name.replace('/physical_device:', '')
            usage = tf.config.experimental.get_memory_info(short_name)
            current_mb = usage['current'] / bytes_per_mb
            peak_mb = usage['peak'] / bytes_per_mb
            print(f"GPU Memory - Current: {current_mb:.1f} MB, Peak: {peak_mb:.1f} MB")
    except Exception as err:
        print(f"Could not get GPU memory info: {err}")
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_1.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_2.png
ADDED
|
Git LFS Details
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_discriminator.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d37fd879442368aaba7813f2549a1dda9c2376be2e6ec000d6352b2901e7207
|
| 3 |
+
size 11107072
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_generator.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b28647f12d0639cb2f15a03e6f9334c45dfff35b25cf90df8504dd10d931598
|
| 3 |
+
size 124213136
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"variant": 1,
|
| 3 |
+
"variant_name": "Multiclass_AttentionD_AdaptiveLoss",
|
| 4 |
+
"preprocessing": "standard",
|
| 5 |
+
"class_scenario": "binary",
|
| 6 |
+
"fold_id": 0,
|
| 7 |
+
"num_classes": 2,
|
| 8 |
+
"batch_size": 4,
|
| 9 |
+
"epochs": 20,
|
| 10 |
+
"lambda_seg": 50,
|
| 11 |
+
"lambda_gan": 1,
|
| 12 |
+
"focal_gamma": 0.5,
|
| 13 |
+
"beta_threshold": 0.25,
|
| 14 |
+
"beta_smoothness": 0.05,
|
| 15 |
+
"learning_rate": 0.0002,
|
| 16 |
+
"beta_1": 0.9,
|
| 17 |
+
"attention_weight": 2.0,
|
| 18 |
+
"innovation": "Phase-transitioning segmentation loss (Weighted CE \u2192 Focal Loss)"
|
| 19 |
+
}
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/download_models.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Visit our Hugging Face link for downloading the trained models.
|
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/history.json
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"gen_total_loss": [
|
| 3 |
+
16.160808563232422,
|
| 4 |
+
15.84605598449707,
|
| 5 |
+
15.596329689025879,
|
| 6 |
+
15.1962308883667,
|
| 7 |
+
15.236580848693848
|
| 8 |
+
],
|
| 9 |
+
"gen_gan_loss": [
|
| 10 |
+
1.4460713863372803,
|
| 11 |
+
1.3656171560287476,
|
| 12 |
+
1.3531222343444824,
|
| 13 |
+
1.1605579853057861,
|
| 14 |
+
1.2858413457870483
|
| 15 |
+
],
|
| 16 |
+
"gen_seg_loss": [
|
| 17 |
+
0.2942947447299957,
|
| 18 |
+
0.28960874676704407,
|
| 19 |
+
0.2848641574382782,
|
| 20 |
+
0.2807134687900543,
|
| 21 |
+
0.27901479601860046
|
| 22 |
+
],
|
| 23 |
+
"gen_wce_loss": [
|
| 24 |
+
0.427320659160614,
|
| 25 |
+
0.5918631553649902,
|
| 26 |
+
0.6427512764930725,
|
| 27 |
+
0.6484421491622925,
|
| 28 |
+
0.6594918966293335
|
| 29 |
+
],
|
| 30 |
+
"gen_focal_loss": [
|
| 31 |
+
0.2933984398841858,
|
| 32 |
+
0.28957146406173706,
|
| 33 |
+
0.28486335277557373,
|
| 34 |
+
0.2807134687900543,
|
| 35 |
+
0.27901479601860046
|
| 36 |
+
],
|
| 37 |
+
"disc_loss": [
|
| 38 |
+
0.9733297228813171,
|
| 39 |
+
0.9525837898254395,
|
| 40 |
+
0.9791364669799805,
|
| 41 |
+
1.0745328664779663,
|
| 42 |
+
0.9908205270767212
|
| 43 |
+
],
|
| 44 |
+
"val_loss": [
|
| 45 |
+
0.2753802239894867,
|
| 46 |
+
0.28040140867233276,
|
| 47 |
+
0.28456518054008484,
|
| 48 |
+
0.28689467906951904,
|
| 49 |
+
0.2823749780654907
|
| 50 |
+
],
|
| 51 |
+
"beta_value": [
|
| 52 |
+
0.9933071732521057,
|
| 53 |
+
0.9998766183853149,
|
| 54 |
+
0.9999977350234985,
|
| 55 |
+
1.0,
|
| 56 |
+
1.0
|
| 57 |
+
],
|
| 58 |
+
"val_metrics": [
|
| 59 |
+
{
|
| 60 |
+
"dice": {
|
| 61 |
+
"class_0": 0.9691289499147928,
|
| 62 |
+
"class_1": 0.752369057690533,
|
| 63 |
+
"mean": 0.8607490038026628
|
| 64 |
+
},
|
| 65 |
+
"precision": {
|
| 66 |
+
"class_0": 0.9797553013827063,
|
| 67 |
+
"class_1": 0.6921517437594977,
|
| 68 |
+
"mean": 0.835953522571102
|
| 69 |
+
},
|
| 70 |
+
"recall": {
|
| 71 |
+
"class_0": 0.9587306304286759,
|
| 72 |
+
"class_1": 0.8240626410757951,
|
| 73 |
+
"mean": 0.8913966357522355
|
| 74 |
+
}
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"dice": {
|
| 78 |
+
"class_0": 0.9697302587837198,
|
| 79 |
+
"class_1": 0.7479798920589907,
|
| 80 |
+
"mean": 0.8588550754213553
|
| 81 |
+
},
|
| 82 |
+
"precision": {
|
| 83 |
+
"class_0": 0.976320417478653,
|
| 84 |
+
"class_1": 0.708180715854344,
|
| 85 |
+
"mean": 0.8422505666664984
|
| 86 |
+
},
|
| 87 |
+
"recall": {
|
| 88 |
+
"class_0": 0.9632284706744223,
|
| 89 |
+
"class_1": 0.7925187994877733,
|
| 90 |
+
"mean": 0.8778736350810978
|
| 91 |
+
}
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"dice": {
|
| 95 |
+
"class_0": 0.9700212734386713,
|
| 96 |
+
"class_1": 0.7450693229965066,
|
| 97 |
+
"mean": 0.857545298217589
|
| 98 |
+
},
|
| 99 |
+
"precision": {
|
| 100 |
+
"class_0": 0.9743977498147104,
|
| 101 |
+
"class_1": 0.7176589260174784,
|
| 102 |
+
"mean": 0.8460283379160944
|
| 103 |
+
},
|
| 104 |
+
"recall": {
|
| 105 |
+
"class_0": 0.9656839348841134,
|
| 106 |
+
"class_1": 0.7746567035718136,
|
| 107 |
+
"mean": 0.8701703192279635
|
| 108 |
+
}
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"dice": {
|
| 112 |
+
"class_0": 0.9707925785575585,
|
| 113 |
+
"class_1": 0.7425312591148605,
|
| 114 |
+
"mean": 0.8566619188362095
|
| 115 |
+
},
|
| 116 |
+
"precision": {
|
| 117 |
+
"class_0": 0.9715280350271916,
|
| 118 |
+
"class_1": 0.7376090600933536,
|
| 119 |
+
"mean": 0.8545685475602727
|
| 120 |
+
},
|
| 121 |
+
"recall": {
|
| 122 |
+
"class_0": 0.9700582347414827,
|
| 123 |
+
"class_1": 0.7475195929191307,
|
| 124 |
+
"mean": 0.8587889138303066
|
| 125 |
+
}
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"dice": {
|
| 129 |
+
"class_0": 0.9707689737822686,
|
| 130 |
+
"class_1": 0.7466375133242413,
|
| 131 |
+
"mean": 0.8587032435532549
|
| 132 |
+
},
|
| 133 |
+
"precision": {
|
| 134 |
+
"class_0": 0.9731953621432797,
|
| 135 |
+
"class_1": 0.730843840931104,
|
| 136 |
+
"mean": 0.8520196015371919
|
| 137 |
+
},
|
| 138 |
+
"recall": {
|
| 139 |
+
"class_0": 0.9683546543618544,
|
| 140 |
+
"class_1": 0.7631288712746304,
|
| 141 |
+
"mean": 0.8657417628182424
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
]
|
| 145 |
+
}
|