Bawil commited on
Commit
4266ba2
·
verified ·
1 Parent(s): 57bfb9e

Upload 33 files

Browse files
Files changed (34) hide show
  1. .gitattributes +20 -0
  2. models/for_GM/class_weights_gm/class_weights_fold0_standard_binary.json +23 -0
  3. models/for_GM/data_splits_sp_gm/SP_GM_fold_assignments.json +1295 -0
  4. models/for_GM/model_training_scripts/p1_compute_class_weights.py +336 -0
  5. models/for_GM/model_training_scripts/p1_data_loader.py +847 -0
  6. models/for_GM/model_training_scripts/p1_pix2pix_var5.py +1313 -0
  7. models/for_GM/model_training_scripts/p1_predict_new_data_gm.py +477 -0
  8. models/for_GM/model_training_scripts/unet_model.py +87 -0
  9. models/for_GM/model_training_scripts/utility_functions.py +97 -0
  10. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_1.png +3 -0
  11. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_2.png +3 -0
  12. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_1.png +3 -0
  13. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_2.png +3 -0
  14. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_1.png +3 -0
  15. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_2.png +3 -0
  16. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_1.png +3 -0
  17. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_2.png +3 -0
  18. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_1.png +3 -0
  19. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_2.png +3 -0
  20. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_1.png +3 -0
  21. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_2.png +3 -0
  22. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_1.png +3 -0
  23. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_2.png +3 -0
  24. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_1.png +3 -0
  25. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_2.png +3 -0
  26. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_1.png +3 -0
  27. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_2.png +3 -0
  28. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_1.png +3 -0
  29. models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_2.png +3 -0
  30. models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_discriminator.h5 +3 -0
  31. models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_generator.h5 +3 -0
  32. models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/config.json +19 -0
  33. models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/download_models.txt +1 -0
  34. models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/history.json +145 -0
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_1.png filter=lfs diff=lfs merge=lfs -text
37
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_2.png filter=lfs diff=lfs merge=lfs -text
38
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_1.png filter=lfs diff=lfs merge=lfs -text
39
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_2.png filter=lfs diff=lfs merge=lfs -text
40
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_1.png filter=lfs diff=lfs merge=lfs -text
41
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_2.png filter=lfs diff=lfs merge=lfs -text
42
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_1.png filter=lfs diff=lfs merge=lfs -text
43
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_2.png filter=lfs diff=lfs merge=lfs -text
44
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_1.png filter=lfs diff=lfs merge=lfs -text
45
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_2.png filter=lfs diff=lfs merge=lfs -text
46
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_1.png filter=lfs diff=lfs merge=lfs -text
47
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_2.png filter=lfs diff=lfs merge=lfs -text
48
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_1.png filter=lfs diff=lfs merge=lfs -text
49
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_2.png filter=lfs diff=lfs merge=lfs -text
50
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_1.png filter=lfs diff=lfs merge=lfs -text
51
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_2.png filter=lfs diff=lfs merge=lfs -text
52
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_1.png filter=lfs diff=lfs merge=lfs -text
53
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_2.png filter=lfs diff=lfs merge=lfs -text
54
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_1.png filter=lfs diff=lfs merge=lfs -text
55
+ models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_2.png filter=lfs diff=lfs merge=lfs -text
models/for_GM/class_weights_gm/class_weights_fold0_standard_binary.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fold_id": 0,
3
+ "class_scenario": "binary",
4
+ "preprocessing": "standard",
5
+ "num_classes": 2,
6
+ "total_pixels": 88539136,
7
+ "class_pixel_counts": [
8
+ 79575838,
9
+ 8963298
10
+ ],
11
+ "class_frequencies": [
12
+ 0.8987645644068629,
13
+ 0.10123543559313702
14
+ ],
15
+ "class_weights": [
16
+ 0.20247246624134155,
17
+ 1.7975275337586585
18
+ ],
19
+ "class_names": [
20
+ "Background",
21
+ "Specialized GM"
22
+ ]
23
+ }
models/for_GM/data_splits_sp_gm/SP_GM_fold_assignments.json ADDED
@@ -0,0 +1,1295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_patients": 268,
4
+ "test_patients": 26,
5
+ "trainval_patients": 242,
6
+ "n_folds": 5,
7
+ "random_seed": 42,
8
+ "datasets": [
9
+ "Local_SAI_GM_sp"
10
+ ]
11
+ },
12
+ "test_set": {
13
+ "patients": [
14
+ "117524",
15
+ "132287",
16
+ "105597",
17
+ "120429",
18
+ "117949",
19
+ "126395",
20
+ "134240",
21
+ "120907",
22
+ "106506",
23
+ "110784",
24
+ "118754",
25
+ "112997",
26
+ "112730",
27
+ "129466",
28
+ "105911",
29
+ "111008",
30
+ "129008",
31
+ "129044",
32
+ "110543",
33
+ "117276",
34
+ "114454",
35
+ "104474",
36
+ "114770",
37
+ "130578",
38
+ "116740",
39
+ "107680"
40
+ ],
41
+ "n_patients": 26
42
+ },
43
+ "folds": {
44
+ "fold_0": {
45
+ "train_patients": [
46
+ "101228",
47
+ "101627",
48
+ "102035",
49
+ "102313",
50
+ "104252",
51
+ "104280",
52
+ "104447",
53
+ "104453",
54
+ "104670",
55
+ "104797",
56
+ "104810",
57
+ "104871",
58
+ "105074",
59
+ "105549",
60
+ "105755",
61
+ "105917",
62
+ "105978",
63
+ "106270",
64
+ "106536",
65
+ "106639",
66
+ "106780",
67
+ "106976",
68
+ "107130",
69
+ "107455",
70
+ "107508",
71
+ "107539",
72
+ "107630",
73
+ "107966",
74
+ "107997",
75
+ "108295",
76
+ "108344",
77
+ "108444",
78
+ "108726",
79
+ "108975",
80
+ "109141",
81
+ "109267",
82
+ "109395",
83
+ "109654",
84
+ "109816",
85
+ "109923",
86
+ "109944",
87
+ "110012",
88
+ "110157",
89
+ "110218",
90
+ "110280",
91
+ "110327",
92
+ "110497",
93
+ "111140",
94
+ "111189",
95
+ "111489",
96
+ "111691",
97
+ "111852",
98
+ "112414",
99
+ "112657",
100
+ "112659",
101
+ "112765",
102
+ "112776",
103
+ "113394",
104
+ "114058",
105
+ "114128",
106
+ "114266",
107
+ "114304",
108
+ "114525",
109
+ "114585",
110
+ "114903",
111
+ "114990",
112
+ "115588",
113
+ "115628",
114
+ "115788",
115
+ "115799",
116
+ "115841",
117
+ "115991",
118
+ "116236",
119
+ "116246",
120
+ "116577",
121
+ "116700",
122
+ "116914",
123
+ "116937",
124
+ "117314",
125
+ "117385",
126
+ "117814",
127
+ "118018",
128
+ "118078",
129
+ "118409",
130
+ "118450",
131
+ "118481",
132
+ "118605",
133
+ "118719",
134
+ "118755",
135
+ "119730",
136
+ "120638",
137
+ "120749",
138
+ "120857",
139
+ "121140",
140
+ "121404",
141
+ "121499",
142
+ "121620",
143
+ "121804",
144
+ "121921",
145
+ "122000",
146
+ "122316",
147
+ "122762",
148
+ "122884",
149
+ "123575",
150
+ "124187",
151
+ "124899",
152
+ "125198",
153
+ "125465",
154
+ "125567",
155
+ "125798",
156
+ "126228",
157
+ "126396",
158
+ "126445",
159
+ "126465",
160
+ "126494",
161
+ "126523",
162
+ "126542",
163
+ "126704",
164
+ "126768",
165
+ "126779",
166
+ "127096",
167
+ "127513",
168
+ "127758",
169
+ "127816",
170
+ "127897",
171
+ "128785",
172
+ "128901",
173
+ "129055",
174
+ "129100",
175
+ "129637",
176
+ "129739",
177
+ "130214",
178
+ "130282",
179
+ "130366",
180
+ "130371",
181
+ "130402",
182
+ "130556",
183
+ "130662",
184
+ "130801",
185
+ "131040",
186
+ "131231",
187
+ "131235",
188
+ "131364",
189
+ "131444",
190
+ "131494",
191
+ "131606",
192
+ "131636",
193
+ "131792",
194
+ "131924",
195
+ "132155",
196
+ "132207",
197
+ "132282",
198
+ "132296",
199
+ "132589",
200
+ "132597",
201
+ "132605",
202
+ "132920",
203
+ "133196",
204
+ "133338",
205
+ "133562",
206
+ "133710",
207
+ "133814",
208
+ "133850",
209
+ "133886",
210
+ "133934",
211
+ "133946",
212
+ "134032",
213
+ "134654",
214
+ "134728",
215
+ "134919",
216
+ "134955",
217
+ "135467",
218
+ "135503",
219
+ "135687",
220
+ "135695",
221
+ "135697",
222
+ "135725",
223
+ "135733",
224
+ "135830",
225
+ "135855",
226
+ "136104",
227
+ "136105",
228
+ "136175",
229
+ "136220",
230
+ "136310",
231
+ "136382",
232
+ "136793",
233
+ "136817",
234
+ "136966",
235
+ "136996",
236
+ "137104",
237
+ "137508",
238
+ "137675"
239
+ ],
240
+ "val_patients": [
241
+ "104420",
242
+ "104518",
243
+ "104520",
244
+ "104899",
245
+ "104937",
246
+ "105302",
247
+ "105465",
248
+ "106063",
249
+ "106200",
250
+ "106905",
251
+ "107233",
252
+ "107739",
253
+ "108807",
254
+ "110540",
255
+ "112055",
256
+ "112378",
257
+ "113046",
258
+ "113845",
259
+ "114836",
260
+ "116268",
261
+ "116768",
262
+ "118660",
263
+ "118807",
264
+ "119095",
265
+ "119224",
266
+ "120781",
267
+ "122020",
268
+ "122288",
269
+ "125626",
270
+ "127511",
271
+ "127545",
272
+ "127870",
273
+ "129164",
274
+ "129916",
275
+ "130308",
276
+ "130373",
277
+ "131919",
278
+ "132371",
279
+ "132812",
280
+ "132896",
281
+ "133340",
282
+ "134197",
283
+ "134555",
284
+ "135628",
285
+ "136144",
286
+ "136589",
287
+ "137168",
288
+ "137617",
289
+ "137624"
290
+ ],
291
+ "n_train": 193,
292
+ "n_val": 49
293
+ },
294
+ "fold_1": {
295
+ "train_patients": [
296
+ "101228",
297
+ "101627",
298
+ "102035",
299
+ "102313",
300
+ "104252",
301
+ "104420",
302
+ "104447",
303
+ "104453",
304
+ "104518",
305
+ "104520",
306
+ "104670",
307
+ "104810",
308
+ "104871",
309
+ "104899",
310
+ "104937",
311
+ "105074",
312
+ "105302",
313
+ "105465",
314
+ "105549",
315
+ "105755",
316
+ "105917",
317
+ "105978",
318
+ "106063",
319
+ "106200",
320
+ "106270",
321
+ "106536",
322
+ "106905",
323
+ "107130",
324
+ "107233",
325
+ "107455",
326
+ "107539",
327
+ "107630",
328
+ "107739",
329
+ "107966",
330
+ "107997",
331
+ "108295",
332
+ "108444",
333
+ "108726",
334
+ "108807",
335
+ "108975",
336
+ "109141",
337
+ "109267",
338
+ "109395",
339
+ "109654",
340
+ "109923",
341
+ "109944",
342
+ "110012",
343
+ "110280",
344
+ "110327",
345
+ "110497",
346
+ "110540",
347
+ "111140",
348
+ "111189",
349
+ "111489",
350
+ "111691",
351
+ "112055",
352
+ "112378",
353
+ "112659",
354
+ "112765",
355
+ "112776",
356
+ "113046",
357
+ "113394",
358
+ "113845",
359
+ "114058",
360
+ "114128",
361
+ "114266",
362
+ "114525",
363
+ "114585",
364
+ "114836",
365
+ "114903",
366
+ "115588",
367
+ "115788",
368
+ "115799",
369
+ "115841",
370
+ "115991",
371
+ "116236",
372
+ "116246",
373
+ "116268",
374
+ "116577",
375
+ "116768",
376
+ "116937",
377
+ "117314",
378
+ "117385",
379
+ "118018",
380
+ "118078",
381
+ "118450",
382
+ "118481",
383
+ "118605",
384
+ "118660",
385
+ "118755",
386
+ "118807",
387
+ "119095",
388
+ "119224",
389
+ "120749",
390
+ "120781",
391
+ "120857",
392
+ "121499",
393
+ "121620",
394
+ "121804",
395
+ "122020",
396
+ "122288",
397
+ "122316",
398
+ "122762",
399
+ "122884",
400
+ "123575",
401
+ "124899",
402
+ "125198",
403
+ "125465",
404
+ "125567",
405
+ "125626",
406
+ "125798",
407
+ "126396",
408
+ "126465",
409
+ "126494",
410
+ "126542",
411
+ "126704",
412
+ "126779",
413
+ "127096",
414
+ "127511",
415
+ "127513",
416
+ "127545",
417
+ "127870",
418
+ "127897",
419
+ "128785",
420
+ "129055",
421
+ "129100",
422
+ "129164",
423
+ "129739",
424
+ "129916",
425
+ "130214",
426
+ "130282",
427
+ "130308",
428
+ "130371",
429
+ "130373",
430
+ "130402",
431
+ "130556",
432
+ "130662",
433
+ "130801",
434
+ "131231",
435
+ "131444",
436
+ "131494",
437
+ "131606",
438
+ "131636",
439
+ "131792",
440
+ "131919",
441
+ "131924",
442
+ "132155",
443
+ "132207",
444
+ "132282",
445
+ "132296",
446
+ "132371",
447
+ "132589",
448
+ "132812",
449
+ "132896",
450
+ "132920",
451
+ "133196",
452
+ "133340",
453
+ "133562",
454
+ "133710",
455
+ "133814",
456
+ "133850",
457
+ "133886",
458
+ "134032",
459
+ "134197",
460
+ "134555",
461
+ "134654",
462
+ "134728",
463
+ "134919",
464
+ "134955",
465
+ "135467",
466
+ "135503",
467
+ "135628",
468
+ "135687",
469
+ "135695",
470
+ "135697",
471
+ "135725",
472
+ "135733",
473
+ "135830",
474
+ "136104",
475
+ "136144",
476
+ "136175",
477
+ "136220",
478
+ "136589",
479
+ "136793",
480
+ "136817",
481
+ "136966",
482
+ "136996",
483
+ "137104",
484
+ "137168",
485
+ "137508",
486
+ "137617",
487
+ "137624",
488
+ "137675"
489
+ ],
490
+ "val_patients": [
491
+ "104280",
492
+ "104797",
493
+ "106639",
494
+ "106780",
495
+ "106976",
496
+ "107508",
497
+ "108344",
498
+ "109816",
499
+ "110157",
500
+ "110218",
501
+ "111852",
502
+ "112414",
503
+ "112657",
504
+ "114304",
505
+ "114990",
506
+ "115628",
507
+ "116700",
508
+ "116914",
509
+ "117814",
510
+ "118409",
511
+ "118719",
512
+ "119730",
513
+ "120638",
514
+ "121140",
515
+ "121404",
516
+ "121921",
517
+ "122000",
518
+ "124187",
519
+ "126228",
520
+ "126445",
521
+ "126523",
522
+ "126768",
523
+ "127758",
524
+ "127816",
525
+ "128901",
526
+ "129637",
527
+ "130366",
528
+ "131040",
529
+ "131235",
530
+ "131364",
531
+ "132597",
532
+ "132605",
533
+ "133338",
534
+ "133934",
535
+ "133946",
536
+ "135855",
537
+ "136105",
538
+ "136310",
539
+ "136382"
540
+ ],
541
+ "n_train": 193,
542
+ "n_val": 49
543
+ },
544
+ "fold_2": {
545
+ "train_patients": [
546
+ "101627",
547
+ "102313",
548
+ "104280",
549
+ "104420",
550
+ "104447",
551
+ "104453",
552
+ "104518",
553
+ "104520",
554
+ "104797",
555
+ "104810",
556
+ "104871",
557
+ "104899",
558
+ "104937",
559
+ "105074",
560
+ "105302",
561
+ "105465",
562
+ "105549",
563
+ "105755",
564
+ "105978",
565
+ "106063",
566
+ "106200",
567
+ "106639",
568
+ "106780",
569
+ "106905",
570
+ "106976",
571
+ "107233",
572
+ "107455",
573
+ "107508",
574
+ "107630",
575
+ "107739",
576
+ "107966",
577
+ "107997",
578
+ "108344",
579
+ "108444",
580
+ "108726",
581
+ "108807",
582
+ "109141",
583
+ "109267",
584
+ "109395",
585
+ "109654",
586
+ "109816",
587
+ "109923",
588
+ "109944",
589
+ "110012",
590
+ "110157",
591
+ "110218",
592
+ "110280",
593
+ "110327",
594
+ "110497",
595
+ "110540",
596
+ "111489",
597
+ "111691",
598
+ "111852",
599
+ "112055",
600
+ "112378",
601
+ "112414",
602
+ "112657",
603
+ "112765",
604
+ "112776",
605
+ "113046",
606
+ "113394",
607
+ "113845",
608
+ "114304",
609
+ "114525",
610
+ "114585",
611
+ "114836",
612
+ "114903",
613
+ "114990",
614
+ "115628",
615
+ "115788",
616
+ "115799",
617
+ "115841",
618
+ "116236",
619
+ "116246",
620
+ "116268",
621
+ "116577",
622
+ "116700",
623
+ "116768",
624
+ "116914",
625
+ "117314",
626
+ "117814",
627
+ "118018",
628
+ "118078",
629
+ "118409",
630
+ "118450",
631
+ "118481",
632
+ "118605",
633
+ "118660",
634
+ "118719",
635
+ "118755",
636
+ "118807",
637
+ "119095",
638
+ "119224",
639
+ "119730",
640
+ "120638",
641
+ "120749",
642
+ "120781",
643
+ "121140",
644
+ "121404",
645
+ "121499",
646
+ "121804",
647
+ "121921",
648
+ "122000",
649
+ "122020",
650
+ "122288",
651
+ "122762",
652
+ "122884",
653
+ "123575",
654
+ "124187",
655
+ "124899",
656
+ "125198",
657
+ "125626",
658
+ "126228",
659
+ "126445",
660
+ "126523",
661
+ "126542",
662
+ "126768",
663
+ "127096",
664
+ "127511",
665
+ "127513",
666
+ "127545",
667
+ "127758",
668
+ "127816",
669
+ "127870",
670
+ "127897",
671
+ "128785",
672
+ "128901",
673
+ "129100",
674
+ "129164",
675
+ "129637",
676
+ "129739",
677
+ "129916",
678
+ "130214",
679
+ "130282",
680
+ "130308",
681
+ "130366",
682
+ "130371",
683
+ "130373",
684
+ "130402",
685
+ "130801",
686
+ "131040",
687
+ "131231",
688
+ "131235",
689
+ "131364",
690
+ "131444",
691
+ "131494",
692
+ "131792",
693
+ "131919",
694
+ "132155",
695
+ "132207",
696
+ "132282",
697
+ "132296",
698
+ "132371",
699
+ "132589",
700
+ "132597",
701
+ "132605",
702
+ "132812",
703
+ "132896",
704
+ "133196",
705
+ "133338",
706
+ "133340",
707
+ "133562",
708
+ "133710",
709
+ "133814",
710
+ "133850",
711
+ "133934",
712
+ "133946",
713
+ "134032",
714
+ "134197",
715
+ "134555",
716
+ "134654",
717
+ "134919",
718
+ "134955",
719
+ "135467",
720
+ "135503",
721
+ "135628",
722
+ "135697",
723
+ "135725",
724
+ "135830",
725
+ "135855",
726
+ "136104",
727
+ "136105",
728
+ "136144",
729
+ "136175",
730
+ "136220",
731
+ "136310",
732
+ "136382",
733
+ "136589",
734
+ "136966",
735
+ "137168",
736
+ "137508",
737
+ "137617",
738
+ "137624",
739
+ "137675"
740
+ ],
741
+ "val_patients": [
742
+ "101228",
743
+ "102035",
744
+ "104252",
745
+ "104670",
746
+ "105917",
747
+ "106270",
748
+ "106536",
749
+ "107130",
750
+ "107539",
751
+ "108295",
752
+ "108975",
753
+ "111140",
754
+ "111189",
755
+ "112659",
756
+ "114058",
757
+ "114128",
758
+ "114266",
759
+ "115588",
760
+ "115991",
761
+ "116937",
762
+ "117385",
763
+ "120857",
764
+ "121620",
765
+ "122316",
766
+ "125465",
767
+ "125567",
768
+ "125798",
769
+ "126396",
770
+ "126465",
771
+ "126494",
772
+ "126704",
773
+ "126779",
774
+ "129055",
775
+ "130556",
776
+ "130662",
777
+ "131606",
778
+ "131636",
779
+ "131924",
780
+ "132920",
781
+ "133886",
782
+ "134728",
783
+ "135687",
784
+ "135695",
785
+ "135733",
786
+ "136793",
787
+ "136817",
788
+ "136996",
789
+ "137104"
790
+ ],
791
+ "n_train": 194,
792
+ "n_val": 48
793
+ },
794
+ "fold_3": {
795
+ "train_patients": [
796
+ "101228",
797
+ "101627",
798
+ "102035",
799
+ "104252",
800
+ "104280",
801
+ "104420",
802
+ "104518",
803
+ "104520",
804
+ "104670",
805
+ "104797",
806
+ "104871",
807
+ "104899",
808
+ "104937",
809
+ "105302",
810
+ "105465",
811
+ "105549",
812
+ "105755",
813
+ "105917",
814
+ "106063",
815
+ "106200",
816
+ "106270",
817
+ "106536",
818
+ "106639",
819
+ "106780",
820
+ "106905",
821
+ "106976",
822
+ "107130",
823
+ "107233",
824
+ "107508",
825
+ "107539",
826
+ "107630",
827
+ "107739",
828
+ "108295",
829
+ "108344",
830
+ "108807",
831
+ "108975",
832
+ "109267",
833
+ "109654",
834
+ "109816",
835
+ "109923",
836
+ "110012",
837
+ "110157",
838
+ "110218",
839
+ "110280",
840
+ "110327",
841
+ "110540",
842
+ "111140",
843
+ "111189",
844
+ "111489",
845
+ "111852",
846
+ "112055",
847
+ "112378",
848
+ "112414",
849
+ "112657",
850
+ "112659",
851
+ "112765",
852
+ "113046",
853
+ "113394",
854
+ "113845",
855
+ "114058",
856
+ "114128",
857
+ "114266",
858
+ "114304",
859
+ "114836",
860
+ "114990",
861
+ "115588",
862
+ "115628",
863
+ "115788",
864
+ "115799",
865
+ "115991",
866
+ "116246",
867
+ "116268",
868
+ "116700",
869
+ "116768",
870
+ "116914",
871
+ "116937",
872
+ "117314",
873
+ "117385",
874
+ "117814",
875
+ "118018",
876
+ "118078",
877
+ "118409",
878
+ "118481",
879
+ "118605",
880
+ "118660",
881
+ "118719",
882
+ "118807",
883
+ "119095",
884
+ "119224",
885
+ "119730",
886
+ "120638",
887
+ "120749",
888
+ "120781",
889
+ "120857",
890
+ "121140",
891
+ "121404",
892
+ "121499",
893
+ "121620",
894
+ "121921",
895
+ "122000",
896
+ "122020",
897
+ "122288",
898
+ "122316",
899
+ "122762",
900
+ "122884",
901
+ "124187",
902
+ "125465",
903
+ "125567",
904
+ "125626",
905
+ "125798",
906
+ "126228",
907
+ "126396",
908
+ "126445",
909
+ "126465",
910
+ "126494",
911
+ "126523",
912
+ "126704",
913
+ "126768",
914
+ "126779",
915
+ "127096",
916
+ "127511",
917
+ "127513",
918
+ "127545",
919
+ "127758",
920
+ "127816",
921
+ "127870",
922
+ "128785",
923
+ "128901",
924
+ "129055",
925
+ "129100",
926
+ "129164",
927
+ "129637",
928
+ "129916",
929
+ "130308",
930
+ "130366",
931
+ "130371",
932
+ "130373",
933
+ "130556",
934
+ "130662",
935
+ "130801",
936
+ "131040",
937
+ "131235",
938
+ "131364",
939
+ "131444",
940
+ "131606",
941
+ "131636",
942
+ "131919",
943
+ "131924",
944
+ "132207",
945
+ "132282",
946
+ "132296",
947
+ "132371",
948
+ "132589",
949
+ "132597",
950
+ "132605",
951
+ "132812",
952
+ "132896",
953
+ "132920",
954
+ "133338",
955
+ "133340",
956
+ "133562",
957
+ "133710",
958
+ "133814",
959
+ "133850",
960
+ "133886",
961
+ "133934",
962
+ "133946",
963
+ "134197",
964
+ "134555",
965
+ "134654",
966
+ "134728",
967
+ "134955",
968
+ "135467",
969
+ "135628",
970
+ "135687",
971
+ "135695",
972
+ "135733",
973
+ "135855",
974
+ "136105",
975
+ "136144",
976
+ "136175",
977
+ "136310",
978
+ "136382",
979
+ "136589",
980
+ "136793",
981
+ "136817",
982
+ "136966",
983
+ "136996",
984
+ "137104",
985
+ "137168",
986
+ "137508",
987
+ "137617",
988
+ "137624",
989
+ "137675"
990
+ ],
991
+ "val_patients": [
992
+ "102313",
993
+ "104447",
994
+ "104453",
995
+ "104810",
996
+ "105074",
997
+ "105978",
998
+ "107455",
999
+ "107966",
1000
+ "107997",
1001
+ "108444",
1002
+ "108726",
1003
+ "109141",
1004
+ "109395",
1005
+ "109944",
1006
+ "110497",
1007
+ "111691",
1008
+ "112776",
1009
+ "114525",
1010
+ "114585",
1011
+ "114903",
1012
+ "115841",
1013
+ "116236",
1014
+ "116577",
1015
+ "118450",
1016
+ "118755",
1017
+ "121804",
1018
+ "123575",
1019
+ "124899",
1020
+ "125198",
1021
+ "126542",
1022
+ "127897",
1023
+ "129739",
1024
+ "130214",
1025
+ "130282",
1026
+ "130402",
1027
+ "131231",
1028
+ "131494",
1029
+ "131792",
1030
+ "132155",
1031
+ "133196",
1032
+ "134032",
1033
+ "134919",
1034
+ "135503",
1035
+ "135697",
1036
+ "135725",
1037
+ "135830",
1038
+ "136104",
1039
+ "136220"
1040
+ ],
1041
+ "n_train": 194,
1042
+ "n_val": 48
1043
+ },
1044
+ "fold_4": {
1045
+ "train_patients": [
1046
+ "101228",
1047
+ "102035",
1048
+ "102313",
1049
+ "104252",
1050
+ "104280",
1051
+ "104420",
1052
+ "104447",
1053
+ "104453",
1054
+ "104518",
1055
+ "104520",
1056
+ "104670",
1057
+ "104797",
1058
+ "104810",
1059
+ "104899",
1060
+ "104937",
1061
+ "105074",
1062
+ "105302",
1063
+ "105465",
1064
+ "105917",
1065
+ "105978",
1066
+ "106063",
1067
+ "106200",
1068
+ "106270",
1069
+ "106536",
1070
+ "106639",
1071
+ "106780",
1072
+ "106905",
1073
+ "106976",
1074
+ "107130",
1075
+ "107233",
1076
+ "107455",
1077
+ "107508",
1078
+ "107539",
1079
+ "107739",
1080
+ "107966",
1081
+ "107997",
1082
+ "108295",
1083
+ "108344",
1084
+ "108444",
1085
+ "108726",
1086
+ "108807",
1087
+ "108975",
1088
+ "109141",
1089
+ "109395",
1090
+ "109816",
1091
+ "109944",
1092
+ "110157",
1093
+ "110218",
1094
+ "110497",
1095
+ "110540",
1096
+ "111140",
1097
+ "111189",
1098
+ "111691",
1099
+ "111852",
1100
+ "112055",
1101
+ "112378",
1102
+ "112414",
1103
+ "112657",
1104
+ "112659",
1105
+ "112776",
1106
+ "113046",
1107
+ "113845",
1108
+ "114058",
1109
+ "114128",
1110
+ "114266",
1111
+ "114304",
1112
+ "114525",
1113
+ "114585",
1114
+ "114836",
1115
+ "114903",
1116
+ "114990",
1117
+ "115588",
1118
+ "115628",
1119
+ "115841",
1120
+ "115991",
1121
+ "116236",
1122
+ "116268",
1123
+ "116577",
1124
+ "116700",
1125
+ "116768",
1126
+ "116914",
1127
+ "116937",
1128
+ "117385",
1129
+ "117814",
1130
+ "118409",
1131
+ "118450",
1132
+ "118660",
1133
+ "118719",
1134
+ "118755",
1135
+ "118807",
1136
+ "119095",
1137
+ "119224",
1138
+ "119730",
1139
+ "120638",
1140
+ "120781",
1141
+ "120857",
1142
+ "121140",
1143
+ "121404",
1144
+ "121620",
1145
+ "121804",
1146
+ "121921",
1147
+ "122000",
1148
+ "122020",
1149
+ "122288",
1150
+ "122316",
1151
+ "123575",
1152
+ "124187",
1153
+ "124899",
1154
+ "125198",
1155
+ "125465",
1156
+ "125567",
1157
+ "125626",
1158
+ "125798",
1159
+ "126228",
1160
+ "126396",
1161
+ "126445",
1162
+ "126465",
1163
+ "126494",
1164
+ "126523",
1165
+ "126542",
1166
+ "126704",
1167
+ "126768",
1168
+ "126779",
1169
+ "127511",
1170
+ "127545",
1171
+ "127758",
1172
+ "127816",
1173
+ "127870",
1174
+ "127897",
1175
+ "128901",
1176
+ "129055",
1177
+ "129164",
1178
+ "129637",
1179
+ "129739",
1180
+ "129916",
1181
+ "130214",
1182
+ "130282",
1183
+ "130308",
1184
+ "130366",
1185
+ "130373",
1186
+ "130402",
1187
+ "130556",
1188
+ "130662",
1189
+ "131040",
1190
+ "131231",
1191
+ "131235",
1192
+ "131364",
1193
+ "131494",
1194
+ "131606",
1195
+ "131636",
1196
+ "131792",
1197
+ "131919",
1198
+ "131924",
1199
+ "132155",
1200
+ "132371",
1201
+ "132597",
1202
+ "132605",
1203
+ "132812",
1204
+ "132896",
1205
+ "132920",
1206
+ "133196",
1207
+ "133338",
1208
+ "133340",
1209
+ "133886",
1210
+ "133934",
1211
+ "133946",
1212
+ "134032",
1213
+ "134197",
1214
+ "134555",
1215
+ "134728",
1216
+ "134919",
1217
+ "135503",
1218
+ "135628",
1219
+ "135687",
1220
+ "135695",
1221
+ "135697",
1222
+ "135725",
1223
+ "135733",
1224
+ "135830",
1225
+ "135855",
1226
+ "136104",
1227
+ "136105",
1228
+ "136144",
1229
+ "136220",
1230
+ "136310",
1231
+ "136382",
1232
+ "136589",
1233
+ "136793",
1234
+ "136817",
1235
+ "136996",
1236
+ "137104",
1237
+ "137168",
1238
+ "137617",
1239
+ "137624"
1240
+ ],
1241
+ "val_patients": [
1242
+ "101627",
1243
+ "104871",
1244
+ "105549",
1245
+ "105755",
1246
+ "107630",
1247
+ "109267",
1248
+ "109654",
1249
+ "109923",
1250
+ "110012",
1251
+ "110280",
1252
+ "110327",
1253
+ "111489",
1254
+ "112765",
1255
+ "113394",
1256
+ "115788",
1257
+ "115799",
1258
+ "116246",
1259
+ "117314",
1260
+ "118018",
1261
+ "118078",
1262
+ "118481",
1263
+ "118605",
1264
+ "120749",
1265
+ "121499",
1266
+ "122762",
1267
+ "122884",
1268
+ "127096",
1269
+ "127513",
1270
+ "128785",
1271
+ "129100",
1272
+ "130371",
1273
+ "130801",
1274
+ "131444",
1275
+ "132207",
1276
+ "132282",
1277
+ "132296",
1278
+ "132589",
1279
+ "133562",
1280
+ "133710",
1281
+ "133814",
1282
+ "133850",
1283
+ "134654",
1284
+ "134955",
1285
+ "135467",
1286
+ "136175",
1287
+ "136966",
1288
+ "137508",
1289
+ "137675"
1290
+ ],
1291
+ "n_train": 194,
1292
+ "n_val": 48
1293
+ }
1294
+ }
1295
+ }
models/for_GM/model_training_scripts/p1_compute_class_weights.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P1 Article - Compute Class Weights from Training Data
3
+ Utility script to calculate inverse frequency weights for class balancing
4
+
5
+ Usage:
6
+ python p1_compute_class_weights.py --fold 0 --scenario binary --preprocessing standard
7
+
8
+ Output:
9
+ Saves class weights to JSON file for reproducibility
10
+ Prints weights for use in training
11
+
12
+ """
13
+
14
+ import numpy as np
15
+ import json
16
+ from pathlib import Path
17
+ from tqdm import tqdm
18
+ import argparse
19
+
20
+ # Import data loader
21
+ from p1_data_loader import DataConfig, P1DataLoader
22
+
23
+
24
def compute_class_frequencies(dataset, num_classes, total_samples=None):
    """
    Count how many mask pixels carry each class label across a dataset.

    Args:
        dataset: Iterable yielding 4-tuples; the second element is a batch of
            target masks exposing ``.numpy()``. The other three elements are
            ignored here.
        num_classes: Number of class ids to count (ids 0 .. num_classes - 1)
        total_samples: Optional number of items the dataset yields; used only
            as the progress-bar total

    Returns:
        class_pixel_counts: int64 array with one pixel count per class
        total_pixels: Total number of pixels visited. Pixels whose label lies
            outside 0 .. num_classes - 1 are included here but in no class count.
    """
    class_pixel_counts = np.zeros(num_classes, dtype=np.int64)
    total_pixels = 0

    print(f"Computing class frequencies for {num_classes}-class scenario...")

    # Show a progress bar only when the caller told us how many items to expect.
    iterator = dataset
    if total_samples is not None:
        iterator = tqdm(dataset, total=total_samples, desc="Processing")

    for _, target_mask, _, _ in iterator:
        masks = np.asarray(target_mask.numpy())

        # Count over the whole batch at once instead of mask by mask.
        for class_id in range(num_classes):
            class_pixel_counts[class_id] += np.count_nonzero(masks == class_id)

        total_pixels += masks.size

    return class_pixel_counts, total_pixels
56
+
57
+
58
def compute_inverse_frequency_weights(class_pixel_counts, num_classes):
    """
    Compute inverse frequency weights with normalization

    Args:
        class_pixel_counts: Array of pixel counts per class
        num_classes: Number of classes; the returned weights sum to this value

    Returns:
        class_weights: Normalized inverse frequency weights
        class_frequencies: Class frequencies (for reference)

    Raises:
        ValueError: If no pixels were counted at all, since frequencies are
            undefined and would otherwise silently become NaN.
    """
    counts = np.asarray(class_pixel_counts, dtype=np.float64)
    total_pixels = counts.sum()

    if total_pixels <= 0:
        raise ValueError(
            "Cannot compute class weights: no pixels were counted "
            "(empty dataset or all-zero class_pixel_counts)."
        )

    # Class frequencies
    class_frequencies = counts / total_pixels

    # Inverse frequency; epsilon keeps the weight of an absent class finite
    epsilon = 1e-6
    inverse_freq = 1.0 / (class_frequencies + epsilon)

    # Normalize weights to sum = num_classes
    # This keeps weights in a reasonable range while maintaining relative importance
    class_weights = inverse_freq / np.sum(inverse_freq) * num_classes

    return class_weights, class_frequencies
84
+
85
+
86
def compute_and_save_class_weights(fold_id, class_scenario, preprocessing,
                                   output_dir='class_weights_gm'):
    """
    Compute class weights for a specific fold and scenario

    Builds the training dataset of the fold, counts pixels per class, derives
    normalized inverse-frequency weights, prints a report and writes the
    result to ``class_weights_fold{fold}_{preprocessing}_{scenario}.json``.

    Args:
        fold_id: Fold number (0-4)
        class_scenario: 'binary'
        preprocessing: 'standard' or 'zoomed'
        output_dir: Directory to save weights (created if missing, including
            any missing parent directories)

    Returns:
        Dictionary with weights and statistics (the same content that is
        written to the JSON file)
    """
    print("\n" + "="*70)
    print("COMPUTING CLASS WEIGHTS")
    print("="*70)
    print(f"Fold: {fold_id}")
    print(f"Scenario: {class_scenario}")
    print(f"Preprocessing: {preprocessing}")
    print("="*70 + "\n")

    # Initialize data loader
    config = DataConfig()
    data_loader = P1DataLoader(config)

    # Only the binary scenario exists, so there are always two classes.
    num_classes = 2

    # Load training dataset
    print("Loading training dataset...")
    batch_size = 4  # Larger batch for faster processing
    train_dataset = data_loader.create_dataset_for_fold(
        fold_id=fold_id,
        split='train',
        preprocessing=preprocessing,
        class_scenario=class_scenario,
        batch_size=batch_size,
        shuffle=False  # No need to shuffle for counting
    )

    # One full pass just to learn how many batches there are; this is used
    # as the progress-bar total for the counting pass below.
    train_size = sum(1 for _ in train_dataset)
    print(f"Training batches (batch_size={batch_size}): {train_size}")

    # Compute class frequencies
    class_pixel_counts, total_pixels = compute_class_frequencies(
        train_dataset, num_classes, train_size
    )

    # Compute inverse frequency weights
    class_weights, class_frequencies = compute_inverse_frequency_weights(
        class_pixel_counts, num_classes
    )

    # Print results
    print("\n" + "="*70)
    print("RESULTS")
    print("="*70)

    class_names = {
        2: ['Background', 'Specialized GM']
    }

    print(f"\nTotal pixels analyzed: {total_pixels:,}")
    print("\nClass Statistics:")
    print("-" * 70)

    for i in range(num_classes):
        print(f"Class {i} ({class_names[num_classes][i]}):")
        print(f" Pixel count: {class_pixel_counts[i]:,}")
        print(f" Frequency: {class_frequencies[i]:.6f} ({class_frequencies[i]*100:.2f}%)")
        print(f" Weight: {class_weights[i]:.4f}")
        print()

    # Save to JSON; parents=True so a nested output_dir also works
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    results = {
        'fold_id': fold_id,
        'class_scenario': class_scenario,
        'preprocessing': preprocessing,
        'num_classes': num_classes,
        'total_pixels': int(total_pixels),
        'class_pixel_counts': class_pixel_counts.tolist(),
        'class_frequencies': class_frequencies.tolist(),
        'class_weights': class_weights.tolist(),
        'class_names': class_names[num_classes]
    }

    filename = f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
    filepath = output_path / filename

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print("="*70)
    print(f"✅ Class weights saved to: {filepath}")
    print("="*70)

    # Print weights in format ready for code
    print("\nFor use in training script:")
    print("-" * 70)
    print(f"class_weights = tf.constant({class_weights.tolist()}, dtype=tf.float32)")
    print()

    return results
193
+
194
+
195
def compute_all_scenarios_for_fold(fold_id):
    """
    Run the class-weight computation for every preprocessing variant of a fold.

    Args:
        fold_id: Fold number (0-4)

    Returns:
        Dictionary keyed by "<preprocessing>_<class_scenario>" holding the
        result dictionary of each run.
    """
    # (preprocessing, class_scenario) pairs, processed in this order
    combinations = (
        ('standard', 'binary'),
        ('zoomed', 'binary'),
    )

    collected = {}
    for prep_name, scenario_name in combinations:
        collected[f"{prep_name}_{scenario_name}"] = compute_and_save_class_weights(
            fold_id=fold_id,
            class_scenario=scenario_name,
            preprocessing=prep_name,
        )
        # Visual separator between consecutive reports
        print("\n" + "="*70 + "\n")

    return collected
222
+
223
+
224
def load_class_weights(fold_id, class_scenario, preprocessing, weights_dir='class_weights_gm'):
    """
    Read class weights that were saved earlier as JSON.

    Args:
        fold_id: Fold number (0-4)
        class_scenario: 'binary'
        preprocessing: 'standard' or 'zoomed'
        weights_dir: Directory containing weights files

    Returns:
        class_weights: NumPy float32 array of weights

    Raises:
        FileNotFoundError: If the expected weights file does not exist.
    """
    filepath = Path(weights_dir) / (
        f"class_weights_fold{fold_id}_{preprocessing}_{class_scenario}.json"
    )

    if not filepath.exists():
        raise FileNotFoundError(
            f"Class weights not found: {filepath}\n"
            f"Run compute_and_save_class_weights() first."
        )

    with open(filepath, 'r') as handle:
        stored = json.load(handle)

    return np.array(stored['class_weights'], dtype=np.float32)
253
+
254
+
255
def main():
    """Parse the command line and run the requested class-weight computation."""
    parser = argparse.ArgumentParser(
        description='Compute class weights from training data',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Single scenario
python p1_compute_class_weights.py --fold 0 --scenario binary --preprocessing standard

# All scenarios for one fold
python p1_compute_class_weights.py --fold 0 --all

# All folds (for completeness)
python p1_compute_class_weights.py --all-folds
"""
    )

    # Declarative option table: (flag, keyword arguments for add_argument)
    option_table = [
        ('--fold', dict(type=int, choices=[0, 1, 2, 3, 4], help='Fold number (0-4)')),
        ('--scenario', dict(type=str, choices=['binary'], help='Class scenario')),
        ('--preprocessing', dict(type=str, choices=['standard', 'zoomed'],
                                 help='Preprocessing type')),
        ('--all', dict(action='store_true',
                       help='Compute for all scenarios of specified fold')),
        ('--all-folds', dict(action='store_true',
                             help='Compute for all scenarios of all folds')),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Mode 1: every scenario of every fold
    if args.all_folds:
        for fold_id in range(5):
            print(f"\n{'='*70}")
            print(f"PROCESSING FOLD {fold_id}")
            print(f"{'='*70}\n")
            compute_all_scenarios_for_fold(fold_id)
        return

    # Mode 2: every scenario of one fold
    if args.all:
        if args.fold is None:
            parser.error("--fold is required when using --all")
        compute_all_scenarios_for_fold(args.fold)
        return

    # Mode 3: a single fold/scenario/preprocessing combination
    if any(value is None for value in (args.fold, args.scenario, args.preprocessing)):
        parser.error("--fold, --scenario, and --preprocessing are required")

    compute_and_save_class_weights(
        fold_id=args.fold,
        class_scenario=args.scenario,
        preprocessing=args.preprocessing
    )


if __name__ == "__main__":
    main()
models/for_GM/model_training_scripts/p1_data_loader.py ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P1 & P4 Articles - Data Loading System
3
+
4
+ Complete implementation for brain segmentation experiments
5
+
6
+ Specialized Gray Matter (GM) Segmentation with U-Net Models - Journal Paper Implementation
7
+ Binary segmentation: Background vs Specialized GM
8
+ Professional results saving and visualization for publication
9
+
10
+ This relates to our articles:
11
+ "Specialized gray matter segmentation via a generative adversarial network:
12
+ application on brain white matter hyperintensities classification"
13
+
14
+ "Deep Learning-Based Neuroanatomical Profiling Reveals Detailed Brain Changes:
15
+ A Large-Scale Multiple Sclerosis Study"
16
+
17
+ Features:
18
+ - Load FLAIR images and individual mask files from Cohort directory
19
+ - Support the Local_SAI_GM_sp dataset
20
+ - Handle standard and zoomed preprocessing variants
21
+ - Combine masks into 2-class format
22
+ - Create paired inputs: [FLAIR | mask] concatenated (256x512)
23
+ - Patient-stratified K-fold cross-validation
24
+ - TensorFlow dataset creation with proper batching
25
+
26
+ Authors:
27
+ "Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil"
28
+
29
+ Developer:
30
+ "Mahdi Bashiri Bawil"
31
+ """
32
+
33
+ import numpy as np
34
+ import os
35
+ from pathlib import Path
36
+ from typing import Tuple, List, Dict, Optional
37
+ import json
38
+ from sklearn.model_selection import KFold
39
+ from tqdm import tqdm
40
+ import cv2 as cv
41
+
42
+ # Deep Learning
43
+ import tensorflow as tf
44
+
45
+
46
+ ###################### Configuration ######################
47
+
48
class DataConfig:
    """
    Data configuration for the specialized-GM segmentation experiments.

    Holds dataset locations, the class scenario definition, K-fold parameters,
    image geometry and the location of the saved fold assignments.
    """

    # Default cohort location used when no cohort_dir is passed.
    DEFAULT_COHORT_DIR = "/mnt/e/MBashiri/ours_articles/Paper#2/Data/Cohort"

    def __init__(self, cohort_dir=None):
        """
        Args:
            cohort_dir: Optional path to the data cohort directory. When
                omitted, DEFAULT_COHORT_DIR is used, so existing callers that
                construct ``DataConfig()`` keep their behavior.
        """
        # Base paths
        self.cohort_dir = Path(
            self.DEFAULT_COHORT_DIR if cohort_dir is None else cohort_dir
        )

        # Dataset configurations
        self.datasets = {
            'Local_SAI_GM_sp': {
                'base_path': self.cohort_dir / 'Local_SAI_GM_sp',
                'slice_range': (1, 20),  # inclusive range of slice numbers
                'patient_prefix_length': 6  # "101228"
            }
        }

        # Preprocessing variants
        self.preprocessing_types = ['standard', 'zoomed']

        # Class scenarios
        self.class_scenarios = {
            'binary': {
                'num_classes': 2,
                'class_names': ['Background', 'Specialized GM'],
                'description': 'Binary: Background, Specialized GM',
                'class_mapping': {
                    'background': 0,
                    'specialized_gm': 1,
                }
            }
        }

        # K-fold parameters
        self.k_folds = 5
        self.test_split = 0.1  # 10% for test set
        self.random_state = 42

        # Image parameters
        self.target_size = (256, 256)
        self.paired_width = 512  # FLAIR (256) + mask (256)

        # Paths for splits (relative to the current working directory)
        self.splits_dir = Path("data_splits_sp_gm")
        self.splits_file = self.splits_dir / "SP_GM_fold_assignments.json"
92
+
93
+
94
+ ###################### Helper Functions ######################
95
+
96
def extract_patient_id(filename: str, prefix_length: int = 6) -> str:
    """
    Return the patient identifier encoded at the start of a slice filename.

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"
        prefix_length: Maximum number of leading characters kept as the ID

    Returns:
        Patient ID: e.g., "101228" or "c01p01"
    """
    # Everything before the first underscore, truncated to the prefix length
    leading_part, _, _ = filename.partition('_')
    return leading_part[:prefix_length]
108
+
109
+
110
def extract_slice_number(filename: str) -> int:
    """
    Return the slice index encoded at the end of a slice filename.

    Args:
        filename: e.g., "101228_5.npy" or "c01p01_25.png"

    Returns:
        Slice number as integer
    """
    # Drop everything from the first dot onward (the extension) ...
    stem = filename.partition('.')[0]
    # ... then keep what follows the last underscore.
    return int(stem.rpartition('_')[2])
125
+
126
+
127
def load_flair_image(flair_path: Path, normalize: bool = False, of_z_score: bool = False) -> np.ndarray:
    """
    Load a FLAIR slice, either from PNG or from its pre-z-scored NPY sibling.

    Args:
        flair_path: Path to the .png file. With ``of_z_score=True`` the file
            with the same name but a ``.npy`` suffix is loaded instead.
        normalize: If True, re-apply z-score normalization when the loaded
            values look un-normalized (std > 2 or |mean| > 1)
        of_z_score: If True, load the already z-scored ``.npy`` data as-is;
            otherwise read the PNG and rescale it to [-1, 1]

    Returns:
        FLAIR image as float32 with a trailing channel axis added when the
        loaded array is 2D

    Raises:
        FileNotFoundError: If the PNG cannot be read, or the NPY file is missing.
    """
    if of_z_score:
        # Swap only the file suffix; a plain str.replace would also rewrite
        # any '.png' that happens to occur in a directory name.
        flair = np.load(Path(flair_path).with_suffix('.npy')).astype(np.float32)
    else:
        # cv.imread signals failure by returning None instead of raising
        raw = cv.imread(str(flair_path), cv.IMREAD_GRAYSCALE)
        if raw is None:
            raise FileNotFoundError(f"Could not load FLAIR image: {flair_path}")
        flair = raw.astype(np.float32)

        # Normalize to [-1, 1]
        lowest = np.min(flair)
        highest = np.max(flair)
        if highest > lowest:
            flair = (flair - lowest) / (highest - lowest)
        else:
            # Constant image: avoid 0/0; it maps to the lower bound (-1)
            flair = np.zeros_like(flair)
        flair = (2 * flair) - 1

    # Ensure correct shape
    if len(flair.shape) == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Additional normalization if needed (should already be normalized)
    if normalize and (np.std(flair) > 2.0 or np.abs(np.mean(flair)) > 1.0):
        # Re-normalize if values seem off
        flair = (flair - np.mean(flair)) / (np.std(flair) + 1e-7)

    return flair
159
+
160
+
161
def load_mask_image(mask_path: Path) -> np.ndarray:
    """
    Read a mask PNG and binarize it.

    Args:
        mask_path: Path to .png file

    Returns:
        Binary mask as uint8 (1 wherever the stored pixel is non-zero, else 0)

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    grayscale = cv.imread(str(mask_path), cv.IMREAD_GRAYSCALE)

    # cv.imread returns None rather than raising when reading fails
    if grayscale is None:
        raise FileNotFoundError(f"Could not load mask: {mask_path}")

    return (grayscale > 0).astype(np.uint8)
181
+
182
+
183
def combine_masks(gm_mask: np.ndarray,
                  class_scenario: str,
                  preprocess: bool = False) -> np.ndarray:
    """
    Turn the specialized GM mask into a class-label map.

    Args:
        gm_mask: Specialized GM mask; any value > 0 counts as foreground
        class_scenario: 'binary' (currently not used by the implementation)
        preprocess: If True, clean the mask morphologically first
            (closing, erosion, then removal of small objects)

    Returns:
        uint8 label map of the same shape: 0 = Background, 1 = Specialized GM
    """
    if preprocess:
        # Imported lazily so scikit-image is only needed when cleaning is on
        from skimage.morphology import remove_small_objects, binary_erosion, binary_closing, disk
        min_object_size = 5
        closing_kernel_size = 2
        erosion_kernel_size = 1

        gm_mask = gm_mask > 0

        gm_mask = binary_closing(gm_mask, disk(closing_kernel_size))
        gm_mask = binary_erosion(gm_mask, disk(erosion_kernel_size))
        gm_mask = remove_small_objects(gm_mask, min_size=min_object_size)

    # Class 0: Background (default)
    # Class 1: Specialized GM
    combined = np.zeros_like(gm_mask, dtype=np.uint8)
    combined[gm_mask > 0] = 1

    return combined
215
+
216
+
217
def is_valid_slice(gm_mask: np.ndarray, min_gm_pixels: Optional[int] = None) -> bool:
    """
    Decide whether a slice should be kept in the dataset.

    By default every slice is accepted, which matches the behavior the
    existing fold/data pipeline was built with. Pass ``min_gm_pixels`` to
    actually filter out slices with too little specialized GM.

    Args:
        gm_mask: Specialized GM mask
        min_gm_pixels: Optional threshold; when given, the slice is valid only
            if the mask sum is strictly greater than this value

    Returns:
        True if the slice is accepted
    """
    if min_gm_pixels is None:
        # Filtering is disabled: accept all slices, including empty ones
        return True

    return bool(np.sum(gm_mask) > min_gm_pixels)
231
+
232
+
233
def create_paired_input(flair: np.ndarray,
                        mask: np.ndarray,
                        brain_mask: np.ndarray,
                        num_classes: int,
                        if_bet=False) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create paired input: [FLAIR | mask] concatenated horizontally

    The caller's ``flair`` and ``mask`` arrays are never modified; when brain
    extraction is requested it is applied to copies.

    Args:
        flair: FLAIR image, 2D or with a trailing channel axis
        mask: Class-label mask with the same height/width as ``flair``
        brain_mask: Brain mask; any value > 0 counts as brain
        num_classes: Divisor used to scale the labels before mapping them to
            [-1, 1]; skipped when it is not positive
        if_bet: If True, set FLAIR outside the brain to the image minimum and
            the mask outside the brain to 0

    Returns:
        Tuple of (paired, mask):
        - paired: FLAIR and scaled mask concatenated along axis 1
        - mask: the label mask actually used (brain-extracted when if_bet)
    """
    # Binarize (any non-zero value becomes 1)
    brain_mask = brain_mask > 0

    # Brain extraction on copies so the inputs stay untouched
    if if_bet:
        flair = flair.copy()
        mask = mask.copy()
        flair[~brain_mask] = np.min(flair)
        mask[~brain_mask] = 0

    # Ensure flair is 3D
    if len(flair.shape) == 2:
        flair = np.expand_dims(flair, axis=-1)

    # Scale labels by num_classes, then map to [-1, 1] like the FLAIR channel
    max_class = num_classes
    mask_normalized = mask.astype(np.float32)
    if max_class > 0:
        mask_normalized = mask_normalized / max_class
    mask_normalized = (2 * mask_normalized) - 1

    mask_3d = np.expand_dims(mask_normalized, axis=-1)

    # Concatenate horizontally: [FLAIR | mask]
    paired = np.concatenate([flair, mask_3d], axis=1)

    return paired, mask
275
+
276
+
277
+ ###################### Patient Stratified Splitting ######################
278
+
279
class PatientStratifiedSplitter:
    """
    Create patient-level train/val/test splits.

    A held-out test set is drawn first; the remaining patients are divided
    into K folds. Splitting is done on patient IDs, so all slices of one
    patient end up in the same subset.
    """

    def __init__(self, config: "DataConfig"):
        """
        Args:
            config: Data configuration; ``splits_dir`` is created if missing
                (including missing parent directories)
        """
        self.config = config
        self.config.splits_dir.mkdir(parents=True, exist_ok=True)

    def collect_all_patients(self) -> Dict[str, List[str]]:
        """
        Collect all unique patient IDs of every configured dataset

        Patient IDs are derived from the .png filenames in the standard
        preprocessed FLAIR directory. Datasets whose directory is missing are
        skipped with a warning.

        Returns:
            Dictionary mapping dataset_name -> sorted list of patient IDs
        """
        all_patients = {}

        for dataset_name, dataset_config in self.config.datasets.items():
            # Path to FLAIR images (standard preprocessing)
            flair_dir = dataset_config['base_path'] / 'FLAIR' / 'Preprocessed' / 'images'

            if not flair_dir.exists():
                print(f"Warning: {flair_dir} does not exist. Skipping {dataset_name}.")
                continue

            # One ID per file; the set removes the per-slice duplicates
            patients = {
                extract_patient_id(
                    flair_file.name,
                    dataset_config['patient_prefix_length']
                )
                for flair_file in flair_dir.glob('*.png')
            }

            all_patients[dataset_name] = sorted(patients)
            print(f"{dataset_name}: {len(all_patients[dataset_name])} patients")

        return all_patients

    def create_patient_stratified_splits(self,
                                         save: bool = True) -> Dict:
        """
        Create patient-level test split and K-fold splits

        Args:
            save: If True, write the assignments to ``config.splits_file``

        Returns:
            Dictionary containing fold assignments ('metadata', 'test_set', 'folds')
        """
        all_patients = self.collect_all_patients()

        # Combine patients of all datasets
        combined_patients = []
        for patients in all_patients.values():
            combined_patients.extend(patients)

        combined_patients = np.array(combined_patients)
        total_patients = len(combined_patients)

        print(f"\nTotal unique patients: {total_patients}")

        # Step 1: Hold out config.test_split of the patients as the test set.
        # A local RandomState produces the same draws as seeding the global
        # NumPy RNG with the same seed, but leaves the global state untouched.
        rng = np.random.RandomState(self.config.random_state)
        test_size = int(total_patients * self.config.test_split)

        test_indices = rng.choice(
            total_patients,
            size=test_size,
            replace=False
        )

        test_patients = combined_patients[test_indices]
        train_val_indices = np.setdiff1d(np.arange(total_patients), test_indices)
        train_val_patients = combined_patients[train_val_indices]

        print(f"Test patients: {len(test_patients)}")
        print(f"Train+Val patients: {len(train_val_patients)}")

        # Step 2: Create K-fold splits on train+val patients
        kfold = KFold(
            n_splits=self.config.k_folds,
            shuffle=True,
            random_state=self.config.random_state
        )

        fold_assignments = {
            'metadata': {
                'total_patients': total_patients,
                'test_patients': len(test_patients),
                'trainval_patients': len(train_val_patients),
                'n_folds': self.config.k_folds,
                'random_seed': self.config.random_state,
                'datasets': list(all_patients.keys())
            },
            'test_set': {
                'patients': test_patients.tolist(),
                'n_patients': len(test_patients)
            },
            'folds': {}
        }

        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(train_val_patients)):
            train_patients_fold = train_val_patients[train_idx]
            val_patients_fold = train_val_patients[val_idx]

            fold_assignments['folds'][f'fold_{fold_idx}'] = {
                'train_patients': train_patients_fold.tolist(),
                'val_patients': val_patients_fold.tolist(),
                'n_train': len(train_patients_fold),
                'n_val': len(val_patients_fold)
            }

            print(f"Fold {fold_idx}: Train={len(train_patients_fold)}, Val={len(val_patients_fold)}")

        # Save to JSON
        if save:
            with open(self.config.splits_file, 'w', encoding='utf-8') as f:
                json.dump(fold_assignments, f, indent=2)
            print(f"\n✅ Fold assignments saved to: {self.config.splits_file}")

        return fold_assignments

    def load_fold_assignments(self) -> Dict:
        """
        Load existing fold assignments from JSON

        Raises:
            FileNotFoundError: If ``config.splits_file`` does not exist.
        """
        if not self.config.splits_file.exists():
            raise FileNotFoundError(
                f"Fold assignments not found: {self.config.splits_file}\n"
                f"Run create_patient_stratified_splits() first."
            )

        with open(self.config.splits_file, 'r', encoding='utf-8') as f:
            fold_assignments = json.load(f)

        return fold_assignments

    def verify_patient_separation(self, fold_assignments: Dict) -> bool:
        """
        Check the fold assignments for patient-level data leakage

        Three checks are run and every violation is printed:
        1. no test patient appears in any fold's train or val list,
        2. no patient is in both train and val of the same fold,
        3. no patient appears in validation more than once across folds.

        Args:
            fold_assignments: Dictionary as produced by
                create_patient_stratified_splits()

        Returns:
            True if no violation was found, otherwise False
        """
        print("\n" + "="*60)
        print("VERIFYING PATIENT SEPARATION")
        print("="*60)

        all_issues = []
        test_patients = set(fold_assignments['test_set']['patients'])

        # Check 1: No patient in both test and train/val
        for fold_name, fold_data in fold_assignments['folds'].items():
            train_patients = set(fold_data['train_patients'])
            val_patients = set(fold_data['val_patients'])

            test_train_overlap = test_patients.intersection(train_patients)
            test_val_overlap = test_patients.intersection(val_patients)

            if test_train_overlap:
                issue = f"{fold_name}: Test-Train overlap: {test_train_overlap}"
                all_issues.append(issue)
                print(f"❌ {issue}")

            if test_val_overlap:
                issue = f"{fold_name}: Test-Val overlap: {test_val_overlap}"
                all_issues.append(issue)
                print(f"❌ {issue}")

        # Check 2: No patient in both train and val within same fold
        for fold_name, fold_data in fold_assignments['folds'].items():
            train_patients = set(fold_data['train_patients'])
            val_patients = set(fold_data['val_patients'])

            train_val_overlap = train_patients.intersection(val_patients)
            if train_val_overlap:
                issue = f"{fold_name}: Train-Val overlap: {train_val_overlap}"
                all_issues.append(issue)
                print(f"❌ {issue}")

        # Check 3: Each patient in validation exactly once
        val_patient_counts = {}
        for fold_data in fold_assignments['folds'].values():
            for patient in fold_data['val_patients']:
                val_patient_counts[patient] = val_patient_counts.get(patient, 0) + 1

        for patient, count in val_patient_counts.items():
            if count != 1:
                issue = f"Patient {patient} in validation {count} times (should be 1)"
                all_issues.append(issue)
                print(f"❌ {issue}")

        if not all_issues:
            print("✅ All patient separation checks passed")
            print("✅ No data leakage detected")
            return True

        print(f"\n❌ Found {len(all_issues)} issues")
        return False
478
+
479
+
480
+ ###################### Data Loader ######################
481
+
482
+ class P1DataLoader:
483
+ """
484
+ Main data loader for P1 experiments
485
+ Handles loading FLAIR and masks, creating paired inputs, TensorFlow datasets
486
+ """
487
+
488
+ def __init__(self, config: DataConfig):
489
+ self.config = config
490
+
491
def get_file_paths(self,
                   patient_id: str,
                   slice_num: int,
                   dataset_name: str,
                   preprocessing: str) -> Dict[str, Path]:
    """
    Build the FLAIR / mask file locations of one patient slice.

    Args:
        patient_id: e.g., "101228" or "c01p01"
        slice_num: Slice number
        dataset_name: 'Local_SAI_GM_sp'
        preprocessing: 'standard' or 'zoomed' (anything other than
            'standard' selects the zoomed image directories)

    Returns:
        Dictionary with keys 'flair', 'gm_mask', 'brain_mask' and
        'zoom_factors'; the last one is None unless preprocessing is 'zoomed'
    """
    root = self.config.datasets[dataset_name]['base_path']

    # FLAIR and ground truth use the same sub-directory layout
    image_subdir = 'images' if preprocessing == 'standard' else 'zoomed/images'
    slice_file = f'{patient_id}_{slice_num}.png'

    flair_root = root / 'FLAIR' / 'Preprocessed'
    truth_root = root / 'GroundTruth' / image_subdir

    # Zooming factors are stored per patient and only exist for zoomed data
    zoom_factors_path = None
    if preprocessing == 'zoomed':
        zoom_factors_path = flair_root / 'zoomed' / 'images' / f'{patient_id}_zooming_factors.npy'

    return {
        'flair': flair_root / image_subdir / slice_file,
        'gm_mask': truth_root / 'GM_Masks' / slice_file,
        'brain_mask': truth_root / 'Brain_Masks' / slice_file,
        'zoom_factors': zoom_factors_path
    }
535
+
536
def load_single_slice(self,
                      patient_id: str,
                      slice_num: int,
                      dataset_name: str,
                      preprocessing: str,
                      class_scenario: str,
                      of_z_score: bool = True,
                      if_bet: bool = True,
                      pre_morph: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load one patient slice from disk and build its paired input.

    Args:
        patient_id: Patient identifier
        slice_num: Slice number
        dataset_name: 'Local_SAI_GM_sp'
        preprocessing: 'standard' or 'zoomed'
        class_scenario: 'binary'
        of_z_score: Forwarded to load_flair_image (load the z-scored data)
        if_bet: Forwarded to create_paired_input (apply the brain mask)
        pre_morph: Forwarded to combine_masks as ``preprocess``

    Returns:
        Tuple of (paired_input, combined_mask) as returned by
        create_paired_input
    """
    files = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

    # Image and the two masks that belong to it
    flair_image = load_flair_image(files['flair'], of_z_score=of_z_score)
    gm_labels = load_mask_image(files['gm_mask'])
    brain_region = load_mask_image(files['brain_mask'])

    label_map = combine_masks(gm_labels, class_scenario, preprocess=pre_morph)

    # The label divisor passed on is fixed to 1
    paired_input, combined_mask = create_paired_input(
        flair_image, label_map, brain_region, num_classes=1, if_bet=if_bet
    )

    return paired_input, combined_mask
580
+
581
def collect_patient_slices(self,
                           patient_list: List[str],
                           dataset_name: str,
                           preprocessing: str) -> List[Tuple[str, int, str]]:
    """
    List every usable (patient, slice) pair of the given patients.

    A slice is kept when its FLAIR, GM-mask and brain-mask files all exist,
    both masks can be loaded, and is_valid_slice() accepts the GM mask.
    Slices that fail loading or validation are counted as skipped.

    Args:
        patient_list: List of patient IDs
        dataset_name: 'Local_SAI_GM_sp'
        preprocessing: 'standard' or 'zoomed'

    Returns:
        List of tuples (patient_id, slice_num, dataset_name)
    """
    first_slice, last_slice = self.config.datasets[dataset_name]['slice_range']

    kept = []
    skipped_empty = 0

    for patient_id in patient_list:
        for slice_num in range(first_slice, last_slice + 1):
            paths = self.get_file_paths(patient_id, slice_num, dataset_name, preprocessing)

            # Silently ignore slices with any missing file
            files_present = (paths['flair'].exists()
                             and paths['gm_mask'].exists()
                             and paths['brain_mask'].exists())
            if not files_present:
                continue

            try:
                gm_mask = load_mask_image(paths['gm_mask'])
                # Loaded only to confirm the brain mask is readable
                load_mask_image(paths['brain_mask'])
                accepted = is_valid_slice(gm_mask)
            except Exception as e:
                print(f"Warning: Could not validate {patient_id}_{slice_num}: {e}")
                skipped_empty += 1
                continue

            if accepted:
                kept.append((patient_id, slice_num, dataset_name))
            else:
                skipped_empty += 1

    if skipped_empty > 0:
        print(f" ⚠️ Skipped {skipped_empty} slices with empty masks")

    return kept
633
+
634
def create_dataset_for_fold(self,
                            fold_id: int,
                            split: str,
                            preprocessing: str,
                            class_scenario: str,
                            batch_size: int = 1,
                            shuffle: bool = True,
                            use_z_scored: bool = True,
                            bet: bool = False) -> tf.data.Dataset:
    """
    Build the tf.data pipeline for one fold/split combination.

    Args:
        fold_id: Fold number; used to build the 'fold_<id>' key (ignored for 'test')
        split: 'train', 'val', or 'test'
        preprocessing: 'standard' or 'zoomed'
        class_scenario: 'binary'
        batch_size: Batch size
        shuffle: Whether to shuffle; only honoured when split == 'train'
        use_z_scored: Accepted for API compatibility; not referenced in this body
        bet: Accepted for API compatibility; not referenced in this body

    Returns:
        tf.data.Dataset yielding batches of
        (paired_input, combined_mask, patient_id, slice_num)

    Raises:
        ValueError: If `split` is unknown, or no slices were found.
    """
    # Resolve the patient IDs belonging to the requested split
    assignments = PatientStratifiedSplitter(self.config).load_fold_assignments()

    if split == 'test':
        patient_list = assignments['test_set']['patients']
    else:
        fold_key = f'fold_{fold_id}'
        if split not in ('train', 'val'):
            raise ValueError(f"Unknown split: {split}")
        patient_list = assignments['folds'][fold_key][f'{split}_patients']

    print(f"\nCreating dataset for fold {fold_id}, split '{split}'")
    print(f"Patients: {len(patient_list)}")

    # Gather (patient, slice, dataset) triples across every configured dataset.
    # No per-dataset patient filtering happens here: the full list is offered
    # to each dataset and missing files are dropped by collect_patient_slices.
    all_patient_slices = []
    for dataset_name in self.config.datasets:
        all_patient_slices.extend(
            self.collect_patient_slices(list(patient_list), dataset_name, preprocessing)
        )

    print(f"Total slices: {len(all_patient_slices)}")

    if not all_patient_slices:
        raise ValueError(f"No data found for fold {fold_id}, split '{split}'")

    def data_generator():
        """Yield one sample per collected slice; slices that fail to load are skipped."""
        for patient_id, slice_num, dataset_name in all_patient_slices:
            try:
                paired_input, combined_mask = self.load_single_slice(
                    patient_id, slice_num, dataset_name,
                    preprocessing, class_scenario
                )
                yield paired_input, combined_mask, patient_id, slice_num
            except Exception as e:
                print(f"Error loading {patient_id}_{slice_num}: {e}")
                continue

    sample_signature = (
        tf.TensorSpec(shape=(256, 512, 1), dtype=tf.float32),  # concatenated image
        tf.TensorSpec(shape=(256, 256), dtype=tf.uint8),       # label mask
        tf.TensorSpec(shape=(), dtype=tf.string),              # patient_id
        tf.TensorSpec(shape=(), dtype=tf.int32),               # slice_num
    )
    dataset = tf.data.Dataset.from_generator(
        data_generator,
        output_signature=sample_signature
    )

    # Cache the raw, unbatched, unshuffled samples so the generator (file
    # loading and mask combination) only runs during the first full pass.
    # Shuffling and batching come afterwards so that batch composition can
    # still change from epoch to epoch.
    dataset = dataset.cache()

    if shuffle and split == 'train':
        # Buffer spans the whole split -> full shuffle, redone every epoch
        dataset = dataset.shuffle(
            buffer_size=len(all_patient_slices),
            reshuffle_each_iteration=True
        )

    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
743
+
744
+
745
+ ###################### Testing & Validation Functions ######################
746
+
747
def test_data_loading():
    """
    Smoke-test the data loading pipeline end to end.

    Runs three checks in order: (1) create and verify the patient-stratified
    fold assignments (the assignments are saved, see ``save=True``),
    (2) load one slice of the first fold-0 training patient, (3) build the
    fold-0 training dataset and pull one batch.

    Returns:
        True if every check passed, False if a check failed.

    Raises:
        ValueError: If the dataset of the selected test patient cannot be
            determined from its ID.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TESTING DATA LOADING")
    print(banner)

    config = DataConfig()

    # Test 1: Create fold assignments
    print("\n[TEST 1] Creating patient stratified splits...")
    splitter = PatientStratifiedSplitter(config)
    fold_assignments = splitter.create_patient_stratified_splits(save=True)

    # Verify patient separation
    is_valid = splitter.verify_patient_separation(fold_assignments)

    if not is_valid:
        print("❌ Patient separation verification failed!")
        return False

    # Test 2: Load a single slice
    print("\n[TEST 2] Loading single slice...")
    loader = P1DataLoader(config)

    # Get a test patient from fold 0 train set
    test_patient = fold_assignments['folds']['fold_0']['train_patients'][0]

    # Determine which dataset this patient belongs to (by ID prefix)
    if test_patient.startswith('1'):
        test_dataset = 'Local_SAI_GM_sp'
        test_slice = 10  # original note: middle of the 8-15 slice range
    else:
        # Fail with an explanation instead of an empty ValueError
        raise ValueError(
            f"Cannot determine dataset for patient '{test_patient}': "
            "expected a patient ID starting with '1'"
        )

    try:
        paired_input, combined_mask = loader.load_single_slice(
            test_patient, test_slice, test_dataset,
            'standard', 'binary'
        )

        print(f"✅ Loaded slice {test_patient}_{test_slice}")
        print(f" Paired input shape: {paired_input.shape}")
        print(f" Combined mask shape: {combined_mask.shape}")
        print(f" Mask unique values: {np.unique(combined_mask)}")

    except Exception as e:
        print(f"❌ Failed to load slice: {e}")
        return False

    # Test 3: Create TensorFlow dataset
    print("\n[TEST 3] Creating TensorFlow dataset...")
    try:
        dataset = loader.create_dataset_for_fold(
            fold_id=0,
            split='train',
            preprocessing='standard',
            class_scenario='binary',
            batch_size=2,
            shuffle=True
        )

        # Get first batch (patient id and slice number are not inspected)
        for batch_paired, batch_masks, _, _ in dataset.take(1):
            print("✅ Created dataset")
            print(f" Batch paired input shape: {batch_paired.shape}")
            print(f" Batch masks shape: {batch_masks.shape}")
            print(f" Paired input dtype: {batch_paired.dtype}")
            print(f" Masks dtype: {batch_masks.dtype}")

    except Exception as e:
        print(f"❌ Failed to create dataset: {e}")
        return False

    print("\n" + banner)
    print("✅ ALL TESTS PASSED")
    print(banner)

    return True
826
+
827
+
828
+ ###################### Main Execution ######################
829
+
830
if __name__ == "__main__":
    # Run the smoke tests, then print a summary banner for the outcome
    success = test_data_loading()
    rule = "=" * 60

    if success:
        headline = "DATA LOADER READY FOR USE"
        follow_up = [
            "\nNext steps:",
            "1. Verify fold_assignments.json created in data_splits/",
            "2. Check that all file paths are correct for your system",
            "3. Proceed to model implementation",
        ]
    else:
        headline = "❌ DATA LOADER TESTS FAILED"
        follow_up = ["\nPlease fix the issues above before proceeding"]

    print("\n" + rule)
    print(headline)
    print(rule)
    for message in follow_up:
        print(message)
847
+
models/for_GM/model_training_scripts/p1_pix2pix_var5.py ADDED
@@ -0,0 +1,1313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P1 Article - Specialized Gray Matter (GM) Segmentation with U-Net Models - Journal Paper Implementation
3
+
4
+ Features:
5
+ - Multi-channel Generator output (softmax)
6
+ - Attention-Weighted PatchGAN Discriminator
7
+ - Adaptive hybrid loss (Weighted Categorical Cross-Entropy & Focal Dice)
8
+ - One-hot encoded targets
9
+ - Class weight computation per fold
10
+ - Optimized for severe class imbalance
11
+ """
12
+
13
+ import tensorflow as tf
14
+ import os
15
+ import time
16
+ import numpy as np
17
+ import matplotlib.pyplot as plt
18
+ from pathlib import Path
19
+ from tqdm import tqdm
20
+ import json
21
+
22
+ from unet_model import build_unet_3class
23
+
24
+ # Import data loader
25
+ from p1_data_loader import DataConfig, P1DataLoader
26
+
27
+ # Import utilities from baseline
28
+ from utility_functions import (
29
+ clear_gpu_memory,
30
+ get_gpu_memory_info,
31
+ )
32
+
33
+ # Import class weights utility
34
+ from p1_compute_class_weights import compute_and_save_class_weights, load_class_weights
35
+
36
+
37
+ print("TensorFlow Version:", tf.__version__)
38
+
39
+ ###################### GPU Configuration ######################
40
+
41
+ # Configure GPU memory growth
42
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected - training will be slow")
else:
    try:
        # Ask TensorFlow to grow GPU memory on demand for every visible GPU
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        print("✅ GPU memory growth enabled")
        print(f" Available GPUs: {len(physical_devices)}")
    except RuntimeError as e:
        # Reported but not fatal: the script continues with default settings
        print(f"GPU configuration error: {e}")
53
+
54
+ """
55
+ GPU Memory Management for Sequential Experiments
56
+ To properly release memory between experiments
57
+ """
58
+
59
+ ###################### Target Preparation ######################
60
+
61
def prepare_inputs(paired_input, target_mask, num_classes):
    """
    Split the paired sample into network input and one-hot target.

    Args:
        paired_input: 4-D batch whose third axis holds FLAIR and mask side by
            side; documented upstream as (bs, 256, 512, 1)
        target_mask: Integer class labels, documented as (bs, 256, 256)
        num_classes: Depth of the one-hot encoding

    Returns:
        flair_normalized: The first 256 columns of `paired_input`. No
            rescaling happens here; the original note states the FLAIR was
            already normalized to [-1, 1] upstream.
        target_onehot: float32 one-hot mask with a trailing `num_classes` axis
    """
    flair_columns = 256  # FLAIR occupies the left half of the paired image

    flair_normalized = paired_input[:, :, :flair_columns, :]
    target_onehot = tf.one_hot(target_mask, depth=num_classes, dtype=tf.float32)

    return flair_normalized, target_onehot
81
+
82
+ ###################### Metrics Calculation ######################
83
+
84
def compute_classwise_metrics(all_val_true, all_val_pred, num_classes, exclude_class=None):
    """
    Compute per-class Dice, Precision and Recall over all validation batches.

    Predictions are hardened with argmax before counting, and the counts are
    pooled over every pixel of every batch (not averaged per image).

    Args:
        all_val_true: List of one-hot ground-truth tensors (..., num_classes)
        all_val_pred: List of per-class score tensors (..., num_classes)
        num_classes: Number of classes
        exclude_class: Class index left out of the evaluation (e.g. 0 for
            background); None evaluates every class

    Returns:
        Dict with keys 'dice', 'precision', 'recall'; each maps
        'class_<idx>' to a float and 'mean' to the mean over evaluated classes.
    """
    eps = 1e-7

    # Merge the batches and flatten to (num_pixels, num_classes)
    truth = tf.reshape(tf.concat(all_val_true, axis=0), [-1, num_classes])
    scores = tf.reshape(tf.concat(all_val_pred, axis=0), [-1, num_classes])

    # Hard predictions, re-encoded as one-hot so columns line up with `truth`
    hard_pred = tf.one_hot(tf.argmax(scores, axis=-1), depth=num_classes)

    truth_np = truth.numpy()
    pred_np = hard_pred.numpy()

    metric_names = ('dice', 'precision', 'recall')
    metrics = {name: {} for name in metric_names}

    for class_idx in range(num_classes):
        if class_idx == exclude_class:
            continue

        true_col = truth_np[:, class_idx]
        pred_col = pred_np[:, class_idx]

        # Confusion counts for this class
        TP = np.sum((true_col == 1) & (pred_col == 1))
        FP = np.sum((true_col == 0) & (pred_col == 1))
        FN = np.sum((true_col == 1) & (pred_col == 0))

        key = f'class_{class_idx}'
        # eps keeps the ratios finite when a class is absent
        metrics['dice'][key] = float((2 * TP) / (2 * TP + FP + FN + eps))
        metrics['precision'][key] = float(TP / (TP + FP + eps))
        metrics['recall'][key] = float(TP / (TP + FN + eps))

    # Means over the evaluated classes only
    for name in metric_names:
        metrics[name]['mean'] = np.mean(list(metrics[name].values()))

    return metrics
150
+
151
+ ###################### Experiment Configuration ######################
152
+
153
class ExperimentConfig:
    """Configuration for multi-class pix2pix experiment

    Constructing an instance has side effects: it creates the result,
    checkpoint and class-weight directories (relative to the current working
    directory) and writes ``config.json`` into the checkpoint directory.
    """

    def __init__(self,
                 variant: int = 1,
                 preprocessing: str = 'standard',
                 class_scenario: str = 'binary',
                 fold_id: int = 0):
        """
        Args:
            variant: Experiment variant number (used in names and paths)
            preprocessing: 'standard' or 'zoomed'
            class_scenario: 'binary'
            fold_id: Cross-validation fold index
        """
        # Experiment identification
        self.variant = variant
        self.preprocessing = preprocessing    # 'standard' or 'zoomed'
        self.class_scenario = class_scenario  # 'binary'
        self.fold_id = fold_id

        # Experiment name
        self.exp_name = f"exp_{variant}_multiclass_{preprocessing}_{class_scenario}_fold{fold_id}"

        # Number of classes: always 2, whatever `class_scenario` is
        # (the previous conditional returned 2 on both branches).
        self.num_classes = 2

        # Training hyperparameters
        self.batch_size = 4
        self.img_width = 256
        self.img_height = 256
        self.epochs = 20

        # Loss weights
        self.lambda_seg = 50  # seg loss weight
        self.lambda_gan = 1   # GAN loss weight

        # Adaptive loss parameters
        self.focal_gamma = 0.5       # Focal loss focusing parameter
        # Fraction of total training progress at which the loss transition is
        # centred (0.25 * epochs); presumably consumed by the beta schedule.
        self.beta_threshold = 0.25
        self.beta_smoothness = 0.05  # Transition width
        self.use_focal_alpha = True  # Use class weights in focal loss

        # Optimizer parameters
        self.learning_rate = 2e-4
        self.beta_1 = 0.9

        # Attention parameters
        self.attention_weight = 2.0  # How much to upweight lesion regions

        # Paths
        scenario_tag = f"{preprocessing}_{class_scenario}"
        fold_tag = f"fold_{fold_id}"
        self.results_dir = Path(f"results_fold_{fold_id}_var_{variant}_bet_zscore_gm")
        self.models_dir = self.results_dir / "models" / scenario_tag
        self.figures_dir = self.results_dir / "figures" / scenario_tag / fold_tag
        self.logs_dir = self.results_dir / "logs" / scenario_tag / fold_tag

        # Create directories
        self.models_dir.mkdir(parents=True, exist_ok=True)
        self.figures_dir.mkdir(parents=True, exist_ok=True)
        self.logs_dir.mkdir(parents=True, exist_ok=True)

        # Checkpoint configuration (parent models_dir already exists)
        self.checkpoint_dir = self.models_dir / fold_tag
        self.checkpoint_dir.mkdir(exist_ok=True)

        # Class weights directory
        self.weights_dir = Path("class_weights_gm")
        self.weights_dir.mkdir(exist_ok=True)

        # Save configuration
        self.save_config()

    def save_config(self):
        """Save experiment configuration to ``<checkpoint_dir>/config.json``."""
        config_dict = {
            'exp_name': self.exp_name,
            'variant': self.variant,
            'variant_name': 'Multiclass_AttentionD_AdaptiveLoss',
            'preprocessing': self.preprocessing,
            'class_scenario': self.class_scenario,
            'fold_id': self.fold_id,
            'num_classes': self.num_classes,
            'batch_size': self.batch_size,
            'epochs': self.epochs,
            'lambda_seg': self.lambda_seg,
            'lambda_gan': self.lambda_gan,
            'focal_gamma': self.focal_gamma,
            'beta_threshold': self.beta_threshold,
            'beta_smoothness': self.beta_smoothness,
            # Previously missing from the saved file although it is a setting
            'use_focal_alpha': self.use_focal_alpha,
            'learning_rate': self.learning_rate,
            'beta_1': self.beta_1,
            'attention_weight': self.attention_weight,
            'innovation': 'Phase-transitioning segmentation loss (Weighted CE → Focal Loss)'
        }

        config_file = self.checkpoint_dir / "config.json"
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2)
244
+
245
+
246
+ ###################### Model Architecture ######################
247
+
248
def downsample(filters, size, apply_norm=True, use_groupnorm=True):
    """
    Downsample block for encoder: strided conv -> optional norm -> LeakyReLU

    Args:
        filters: Number of filters (output channels of the convolution)
        size: Kernel size
        apply_norm: Whether to apply normalization
        use_groupnorm: If True, use GroupNorm (independent of batch size)
            If False, use BatchNorm (original pix2pix)

    Returns:
        tf.keras.Sequential holding the block's layers
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(
            filters, size, strides=2, padding='same',
            kernel_initializer=initializer, use_bias=False
        )
    )

    if apply_norm:
        if use_groupnorm:
            # Aim for about 8 channels per group, capped at 32 groups
            # (64 -> 8, 128 -> 16, 256 and above -> 32).
            groups = min(32, max(1, filters // 8))
            # GroupNormalization needs the channel count to be divisible by
            # the group count; step down to the nearest divisor so unusual
            # filter counts do not fail at build time. Unchanged for
            # multiples of 8 up to 256 and for multiples of 32.
            while filters % groups:
                groups -= 1
            result.add(tf.keras.layers.GroupNormalization(groups=groups))
        else:
            # Original pix2pix BatchNorm
            result.add(tf.keras.layers.BatchNormalization(momentum=0.99))

    result.add(tf.keras.layers.LeakyReLU())

    return result
282
+
283
+
284
def build_attention_discriminator(num_classes: int, input_channels: int = 1,
                                  attention_weight: float = 2.0, use_groupnorm: bool = True):
    """
    Build Attention-Weighted PatchGAN Discriminator

    The patch logits are multiplied by a spatial weight map derived from the
    mask input: pixels whose argmax class is 0 weigh 1.0, every other pixel
    weighs `attention_weight`.

    Args:
        num_classes: Channel count of the (one-hot / softmax) mask input
        input_channels: Channel count of the image input
        attention_weight: Weight for non-background pixels (>1.0 upweights them)
        use_groupnorm: If True, use GroupNorm instead of BatchNorm

    Returns:
        tf.keras.Model named 'AttentionDiscriminator' taking
        [input_image, target_mask] and returning the weighted patch map
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    # Image input
    image_in = tf.keras.layers.Input(
        shape=[256, 256, input_channels],
        name='input_image'
    )

    # Mask input (one channel per class)
    mask_in = tf.keras.layers.Input(
        shape=[256, 256, num_classes],
        name='target_mask'
    )

    # Spatial attention map from the mask: 1.0 on class 0, attention_weight elsewhere
    class_indices = tf.argmax(mask_in, axis=-1, output_type=tf.int32)  # (bs, 256, 256)
    is_background = class_indices == 0
    background_w = tf.ones_like(class_indices, dtype=tf.float32)
    foreground_w = tf.ones_like(class_indices, dtype=tf.float32) * attention_weight
    attention_map = tf.where(is_background, background_w, foreground_w)
    attention_map = tf.expand_dims(attention_map, axis=-1)  # (bs, 256, 256, 1)

    # Image and mask are judged jointly
    joint = tf.keras.layers.concatenate([image_in, mask_in])  # (bs, 256, 256, input_channels+num_classes)

    # Standard PatchGAN backbone
    feat = downsample(64, 4, apply_norm=False, use_groupnorm=use_groupnorm)(joint)  # (bs, 128, 128, 64)
    feat = downsample(128, 4, use_groupnorm=use_groupnorm)(feat)  # (bs, 64, 64, 128)
    feat = downsample(256, 4, use_groupnorm=use_groupnorm)(feat)  # (bs, 32, 32, 256)

    padded = tf.keras.layers.ZeroPadding2D()(feat)  # (bs, 34, 34, 256)
    conv_out = tf.keras.layers.Conv2D(
        512, 4, strides=1,
        kernel_initializer=initializer,
        use_bias=False
    )(padded)  # (bs, 31, 31, 512)

    if use_groupnorm:
        normed = tf.keras.layers.GroupNormalization(groups=8)(conv_out)
    else:
        normed = tf.keras.layers.BatchNormalization()(conv_out)
    activated = tf.keras.layers.LeakyReLU()(normed)

    padded_again = tf.keras.layers.ZeroPadding2D()(activated)  # (bs, 33, 33, 512)

    # One logit per receptive-field patch
    patch_output = tf.keras.layers.Conv2D(
        1, 4, strides=1,
        kernel_initializer=initializer,
        name='patch_predictions'
    )(padded_again)  # (bs, 30, 30, 1)

    # Bring the attention map down to patch resolution in two steps:
    # stride-8 average pooling with 'same' padding (256 -> 32 per side) ...
    attention_pooled = tf.keras.layers.AveragePooling2D(
        pool_size=(9, 9), strides=(8, 8), padding='same'
    )(attention_map)

    # ... then a bilinear resize to the exact patch grid
    attention_resized = tf.image.resize(
        attention_pooled,
        [tf.shape(patch_output)[1], tf.shape(patch_output)[2]],
        method='bilinear'
    )

    # Scale each patch logit by its attention weight
    weighted_output = patch_output * attention_resized

    return tf.keras.Model(
        inputs=[image_in, mask_in],
        outputs=weighted_output,
        name='AttentionDiscriminator'
    )
374
+
375
+
376
+ ###################### Beta Scheduling ######################
377
+
378
def smooth_step(x, threshold=0.5, smoothness=0.1):
    """
    Smooth step function for phase transition

    Evaluates sigmoid((x - threshold) / smoothness).

    Args:
        x: Current progress (typically epoch / total_epochs)
        threshold: Centre of the transition (output is 0.5 at x == threshold)
        smoothness: Width of the transition (smaller = sharper); must be non-zero

    Returns:
        Value in (0, 1):
        - x << threshold: close to 0
        - x == threshold: 0.5
        - x >> threshold: close to 1

    Example (threshold=0.5, smoothness=0.1):
        smooth_step(0.3) = sigmoid(-2) ~ 0.12   (mostly phase 1)
        smooth_step(0.5) = sigmoid(0)  = 0.5    (equal mix)
        smooth_step(0.7) = sigmoid(2)  ~ 0.88   (mostly phase 2)
    """
    # Distance from the threshold, measured in units of `smoothness`
    scaled_offset = (x - threshold) / smoothness
    return tf.sigmoid(scaled_offset)
411
+
412
+
413
def compute_beta_schedule(current_epoch, total_epochs,
                          threshold=0.5, smoothness=0.1):
    """
    Compute beta value for current epoch

    Args:
        current_epoch: Current epoch number (0-indexed)
        total_epochs: Total number of epochs
        threshold: Transition centre as a fraction of training (0.5 = midpoint)
        smoothness: Transition width

    Returns:
        Beta value in [0, 1], from `smooth_step` applied to
        current_epoch / total_epochs
    """
    elapsed = tf.cast(current_epoch, tf.float32)
    horizon = tf.cast(total_epochs, tf.float32)
    return smooth_step(elapsed / horizon, threshold, smoothness)
430
+
431
+
432
+ ###################### Loss Functions ######################
433
+
434
def unified_focal_loss(y_true, y_pred, gamma=2.0, alpha=None, exclude_class=None):
    """
    Unified Focal Loss

    Per-pixel focal loss, FL = -alpha * (1 - p_t)^gamma * log(p_t), where p_t
    is the predicted probability of the true class.

    Args:
        y_true: Ground truth labels (bs, H, W, num_classes) one-hot encoded
        y_pred: Predicted probabilities (bs, H, W, num_classes)
        gamma: Focusing parameter (default 2.0)
            - gamma=0: the focal factor is 1, leaving plain cross-entropy
            - gamma>0: down-weights pixels that are already well classified
        alpha: Optional per-class weights (num_classes,); each pixel is scaled
            by the weight of its true class
        exclude_class: Optional class index; pixels whose true class equals it
            contribute nothing and are left out of the normalisation

    Returns:
        Scalar loss: mean over all pixels, or (when `exclude_class` is given)
        the sum over kept pixels divided by the number of kept pixels.
    """
    # Keep log() finite
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # Probability assigned to the true class, shape (bs, H, W)
    p_t = tf.reduce_sum(y_true * y_pred, axis=-1)

    # (1 - p_t)^gamma * (-log p_t)
    per_pixel = tf.pow(1.0 - p_t, gamma) * (-tf.math.log(p_t))

    if alpha is not None:
        # Broadcast class weights over the batch and spatial axes, then pick
        # the weight of each pixel's true class
        class_w = tf.reshape(tf.cast(alpha, dtype=tf.float32), [1, 1, 1, -1])
        per_pixel = tf.reduce_sum(y_true * class_w, axis=-1) * per_pixel

    if exclude_class is None:
        return tf.reduce_mean(per_pixel)

    # 1.0 where the pixel's true class is kept, 0.0 where it is excluded
    keep = tf.cast(tf.argmax(y_true, axis=-1) != exclude_class, tf.float32)
    return tf.reduce_sum(per_pixel * keep) / (tf.reduce_sum(keep) + 1e-7)
509
+
510
def unified_focal_dice_loss(y_true, y_pred, gamma=0.5, delta=0.6, alpha=None, exclude_class=None):
    """
    Unified Focal Loss - Dice-based

    For every evaluated class the soft TP/FP/FN over the whole batch give
    Dice, precision and recall, combined as
        (1 - Dice)^gamma * (1 - precision * recall)^delta
    The per-class terms are optionally weighted and then averaged.

    Args:
        y_true: Ground truth one-hot (bs, H, W, num_classes)
        y_pred: Predicted probabilities (bs, H, W, num_classes); the class
            dimension must be statically known
        gamma: Exponent of the (1 - Dice) term (default 0.5)
        delta: Exponent of the (1 - precision * recall) term (default 0.6)
        alpha: Per-class weights (num_classes,) - optional
        exclude_class: Class index to exclude from loss

    Returns:
        Scalar loss value

    Raises:
        ValueError: If the class dimension of `y_pred` is unknown, or if no
            class remains after applying `exclude_class`.
    """
    smooth = 1e-6
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # The Python loop below needs a concrete class count, so read the static
    # shape. (tf.shape() returns a tensor, never an int, so it cannot drive range().)
    num_classes = y_pred.shape[-1]
    if num_classes is None:
        raise ValueError(
            "unified_focal_dice_loss requires y_pred to have a statically "
            "known class dimension"
        )

    unified_losses = []

    for class_idx in range(num_classes):
        # Skip excluded class
        if exclude_class is not None and class_idx == exclude_class:
            continue

        # Flatten this class's channel over batch and space
        y_true_f = tf.reshape(y_true[..., class_idx], [-1])
        y_pred_f = tf.reshape(y_pred[..., class_idx], [-1])

        # Soft true positives, false positives, false negatives
        tp = tf.reduce_sum(y_true_f * y_pred_f)
        fp = tf.reduce_sum((1.0 - y_true_f) * y_pred_f)
        fn = tf.reduce_sum(y_true_f * (1.0 - y_pred_f))

        # Precision and recall
        precision = (tp + smooth) / (tp + fp + smooth)
        recall = (tp + smooth) / (tp + fn + smooth)

        # Dice coefficient
        dice = (2.0 * tp + smooth) / (2.0 * tp + fp + fn + smooth)

        # Grows as Dice falls and as precision * recall falls
        unified_loss_class = tf.pow(1.0 - dice, gamma) * tf.pow(1.0 - precision * recall, delta)

        # Apply class weights
        if alpha is not None:
            unified_loss_class = unified_loss_class * tf.cast(alpha[class_idx], tf.float32)

        unified_losses.append(unified_loss_class)

    if not unified_losses:
        # e.g. a single-class prediction whose only class is excluded
        raise ValueError(
            "unified_focal_dice_loss has no class left to evaluate "
            f"(num_classes={num_classes}, exclude_class={exclude_class})"
        )

    # Mean across the evaluated classes (excluded class not counted)
    total_loss = tf.reduce_mean(tf.stack(unified_losses))

    return total_loss
580
+
581
+
582
def weighted_categorical_crossentropy(y_true, y_pred, class_weights, exclude_class=None):
    """
    Weighted categorical cross-entropy loss

    Each pixel's cross-entropy is scaled by the weight of its true class.

    Args:
        y_true: (bs, H, W, num_classes) one-hot encoded
        y_pred: (bs, H, W, num_classes) probabilities
        class_weights: (num_classes,) weight per class
        exclude_class: Optional int; pixels whose true class equals it are
            masked out and not counted in the normalisation

    Returns:
        Scalar loss: mean over all pixels, or (when `exclude_class` is given)
        the sum over kept pixels divided by the number of kept pixels.
    """
    # Keep log() finite
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

    # Per-pixel cross-entropy, shape (bs, H, W)
    ce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)

    # (num_classes,) -> (1, 1, 1, num_classes) so it broadcasts against y_true;
    # summing over the class axis selects the true class's weight per pixel
    weights_4d = tf.reshape(tf.cast(class_weights, dtype=tf.float32), [1, 1, 1, -1])
    pixel_weights = tf.reduce_sum(y_true * weights_4d, axis=-1)

    weighted = ce * pixel_weights

    if exclude_class is None:
        return tf.reduce_mean(weighted)

    # 1.0 where the pixel's true class is kept, 0.0 where it is excluded
    keep = tf.cast(tf.argmax(y_true, axis=-1) != exclude_class, tf.float32)
    return tf.reduce_sum(weighted * keep) / (tf.reduce_sum(keep) + 1e-7)
619
+
620
+ # Combined Adaptive Loss #
621
+
622
def adaptive_segmentation_loss(y_true, y_pred, class_weights, beta,
                               focal_gamma=0.5, use_focal_alpha=True,
                               exclude_class=None):
    """
    Blend weighted cross-entropy and the unified focal/Dice loss.

    The result is ``(1 - beta) * weighted_CE + beta * focal``, so the caller
    can move smoothly from one objective to the other by scheduling ``beta``.

    Args:
        y_true: Ground truth (bs, H, W, num_classes) one-hot
        y_pred: Predictions (bs, H, W, num_classes) softmax probabilities
        class_weights: Class weights (num_classes,)
        beta: Transition parameter in [0, 1]
            - beta=0: pure weighted CE (early training)
            - beta=1: pure focal loss (late training)
        focal_gamma: Focusing parameter forwarded as ``gamma`` (default 0.5)
        use_focal_alpha: If true, ``class_weights`` are forwarded as the focal
            ``alpha``; otherwise ``alpha`` is None
        exclude_class: Optional class index forwarded to both component losses

    Returns:
        seg_loss: Combined loss
        wcce_loss: Weighted CE component (for monitoring)
        focal_loss: Focal component (for monitoring)

    Phase behaviour (the epoch ranges depend on the beta schedule in use):
        beta ≈ 0      → weighted CE dominates: basic class separation with
                        explicit class weighting
        beta 0 → 1    → gradual change of the loss landscape
        beta ≈ 1      → focal loss dominates: hard examples, boundaries and
                        difficult regions
    """
    # Phase 1 component: weighted cross-entropy
    ce_component = weighted_categorical_crossentropy(
        y_true, y_pred, class_weights, exclude_class=exclude_class
    )

    # Phase 2 component: unified focal/Dice loss, optionally class-weighted
    if use_focal_alpha:
        alpha_for_focal = class_weights
    else:
        alpha_for_focal = None
    focal_component = unified_focal_dice_loss(
        y_true, y_pred,
        gamma=focal_gamma,
        alpha=alpha_for_focal,
        exclude_class=exclude_class,
    )

    # Convex combination controlled by beta:
    #   beta=0   -> ce only, beta=1 -> focal only, beta=0.5 -> equal mix
    blended = (1.0 - beta) * ce_component + beta * focal_component

    return blended, ce_component, focal_component
676
+
677
+
678
# Binary cross-entropy for the GAN loss; from_logits=True means the
# discriminator outputs passed to it are treated as raw logits.
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(disc_generated_output, gen_output, target_onehot,
                   class_weights, beta, lambda_gan=1, lambda_seg=100,
                   focal_gamma=2.0, use_focal_alpha=True):
    """
    Generator loss: GAN loss + adaptive segmentation loss.

    Args:
        disc_generated_output: Discriminator output for generated mask
        gen_output: Generated mask (bs, 256, 256, num_classes) softmax
        target_onehot: Target mask (bs, 256, 256, num_classes) one-hot
        class_weights: (num_classes,) weight per class
        beta: Phase transition parameter [0, 1]
        lambda_gan: Weight for GAN loss (default 1)
        lambda_seg: Weight for segmentation loss (default 100)
        focal_gamma: Focal loss focusing parameter (default 2.0)
        use_focal_alpha: Whether to use class weights in focal loss. Must be a
            plain Python bool: it selects a branch while the graph is built.

    Returns:
        total_gen_loss: lambda_gan * gan_loss + lambda_seg * seg_loss
        gan_loss: Adversarial component
        seg_loss: Adaptive segmentation loss
        wcce_loss: Weighted cross-entropy component (for monitoring)
        focal_loss: Focal component (for monitoring)
    """
    # GAN loss: fool the discriminator (target label "real" for generated masks)
    gan_loss = bce_loss(
        tf.ones_like(disc_generated_output),
        disc_generated_output
    )

    # Adaptive segmentation loss. The caller's use_focal_alpha is forwarded;
    # it used to be hard-coded to True, which made the parameter a no-op.
    seg_loss, wcce_loss, focal_loss = adaptive_segmentation_loss(
        target_onehot, gen_output, class_weights, beta,
        focal_gamma=focal_gamma, use_focal_alpha=use_focal_alpha
    )

    # Total generator loss
    total_gen_loss = (lambda_gan * gan_loss) + (lambda_seg * seg_loss)

    return total_gen_loss, gan_loss, seg_loss, wcce_loss, focal_loss
717
+
718
+
719
def discriminator_loss(disc_real_output, disc_generated_output):
    """
    Discriminator loss: score real masks as 1 and generated masks as 0.

    Args:
        disc_real_output: Discriminator output for real mask
        disc_generated_output: Discriminator output for generated mask

    Returns:
        Sum of the binary cross-entropy on the real and the generated outputs
    """
    real_labels = tf.ones_like(disc_real_output)
    fake_labels = tf.zeros_like(disc_generated_output)

    loss_on_real = bce_loss(real_labels, disc_real_output)
    loss_on_fake = bce_loss(fake_labels, disc_generated_output)

    return loss_on_real + loss_on_fake
743
+
744
+
745
+ ###################### Training Functions ######################
746
+
747
@tf.function
def train_step(input_image, target_onehot, generator, discriminator,
               generator_optimizer, discriminator_optimizer,
               class_weights_np, beta_value,
               lambda_gan, lambda_seg, focal_gamma, use_focal_alpha):
    """
    Run one optimisation step for both generator and discriminator.

    Args:
        input_image: Input FLAIR (bs, 256, 256, 1) in [-1, 1]
        target_onehot: Target mask (bs, 256, 256, num_classes) one-hot
        generator, discriminator: Models to update
        generator_optimizer, discriminator_optimizer: Their optimizers
        class_weights_np: (num_classes,) weight per class
        beta_value: Current beta for phase transition
        lambda_gan, lambda_seg: Loss weights
        focal_gamma: Focal loss parameter
        use_focal_alpha: Whether to use class weights in focal

    Returns:
        Tuple (gen_total_loss, gen_gan_loss, gen_seg_loss, gen_wce_loss,
        gen_focal_loss, disc_loss, class_weights_np)
    """
    # Both tapes record the same forward pass; each is used for one model.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # Discriminator sees (input, real mask) and (input, generated mask)
        disc_real_output = discriminator([input_image, target_onehot], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        # Generator loss (adaptive)
        generator_losses = generator_loss(
            disc_generated_output, gen_output, target_onehot,
            class_weights_np, beta_value, lambda_gan, lambda_seg,
            focal_gamma, use_focal_alpha
        )
        (gen_total_loss, gen_gan_loss, gen_seg_loss,
         gen_wce_loss, gen_focal_loss) = generator_losses

        # Discriminator loss
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    # Gradients of each loss with respect to its own model only
    generator_gradients = gen_tape.gradient(gen_total_loss, gen_vars)
    discriminator_gradients = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(generator_gradients, gen_vars))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, disc_vars))

    return (gen_total_loss, gen_gan_loss, gen_seg_loss, gen_wce_loss,
            gen_focal_loss, disc_loss, class_weights_np)
810
+
811
+
812
def generate_and_save_images(generator, test_input, test_target,
                             epoch, save_path, num_classes):
    """
    Generate predictions and save one five-panel figure per sample.

    Panels: input FLAIR, ground truth, predicted classes, max softmax
    probability, and a binary error map.

    Args:
        generator: Generator model
        test_input: Test input image (bs, 256, 512, 1); only the first 256
            columns are fed to the generator
        test_target: Test target mask (bs, 256, 256)
        epoch: Current epoch number (used in the file name)
        save_path: Directory (pathlib.Path) that receives
            ``epoch_{epoch:03d}_{sample}.png``; created if missing
        num_classes: Number of classes (colour-scale upper bound)
    """
    # Make sure the target directory exists before the first savefig call.
    save_path.mkdir(parents=True, exist_ok=True)

    # Read the batch size from the static shape instead of copying the whole
    # tensor to NumPy just to inspect its shape.
    num_samples = int(test_input.shape[0])

    for ik in range(num_samples):
        # Extract FLAIR (left half of the paired image) and restore batch dim
        # NOTE(review): the training loop passes inputs through
        # prepare_inputs() before the generator, this function does not —
        # confirm the raw slice is already in the range the generator expects.
        flair_normalized = tf.expand_dims(test_input[ik, :, :256, :], axis=0)

        # Generate prediction
        prediction_softmax = generator(flair_normalized, training=False)

        # Convert to class labels
        pred_classes = tf.argmax(prediction_softmax, axis=-1).numpy()
        target_mask = test_target[ik].numpy()

        fig = plt.figure(figsize=(20, 5))
        try:
            # Input FLAIR
            plt.subplot(1, 5, 1)
            plt.title('Input FLAIR')
            plt.imshow(flair_normalized[0, :, :, 0], cmap='gray')
            plt.axis('off')

            # Ground truth
            plt.subplot(1, 5, 2)
            plt.title('Ground Truth')
            plt.imshow(target_mask, cmap='jet', vmin=0, vmax=num_classes-1)
            plt.colorbar()
            plt.axis('off')

            # Prediction
            plt.subplot(1, 5, 3)
            plt.title('Predicted Classes')
            plt.imshow(pred_classes[0], cmap='jet', vmin=0, vmax=num_classes-1)
            plt.colorbar()
            plt.axis('off')

            # Confidence of the winning class
            plt.subplot(1, 5, 4)
            plt.title('Max Probability')
            max_prob = tf.reduce_max(prediction_softmax[0], axis=-1).numpy()
            plt.imshow(max_prob, cmap='viridis', vmin=0, vmax=1)
            plt.colorbar()
            plt.axis('off')

            # Difference map: 1 where prediction and target disagree
            plt.subplot(1, 5, 5)
            plt.title('Error Map (Red=Wrong)')
            error_map = (pred_classes[0] != target_mask).astype(float)
            plt.imshow(error_map, cmap='Reds', vmin=0, vmax=1)
            plt.colorbar()
            plt.axis('off')

            plt.tight_layout()
            plt.savefig(save_path / f'epoch_{epoch:03d}_{ik+1}.png', dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed;
            # otherwise figures accumulate across epochs.
            plt.close(fig)
879
+
880
+
881
+ ###################### Main Training Function ######################
882
+
883
def _history_to_jsonable(value):
    """Recursively convert numbers (incl. NumPy scalars/arrays) to JSON-safe Python types."""
    if isinstance(value, dict):
        return {key: _history_to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_history_to_jsonable(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (int, float, np.number)):
        return float(value)
    return value


def train_experiment_with_metrics(config: ExperimentConfig):
    """
    Main training function for multi-class pix2pix with attention on
    discriminator and adaptive loss.

    Steps: load the train/val datasets for ``config.fold_id``, load or compute
    class weights, build generator/discriminator and their optimizers, restore
    any existing checkpoint / best weights, then train for ``config.epochs``
    epochs with a per-epoch validation pass, visualisation and best-model
    saving. The training history is written to ``history.json`` inside
    ``config.checkpoint_dir``.

    Args:
        config: ExperimentConfig object

    Returns:
        (history, history_file): the history dict and the path of the JSON
        file it was written to.

    Raises:
        FileNotFoundError: if ``config.num_classes`` is not 2 (type kept for
            backward compatibility; only the binary scenario is supported).
    """
    print("\n" + "="*70)
    print(f"TRAINING EXPERIMENT: {config.exp_name}")
    print("="*70)
    print(f"Variant: {config.variant} (Baseline + AttentionD + Adaptive Loss)")
    print(f"Preprocessing: {config.preprocessing}")
    print(f"Class scenario: {config.class_scenario} ({config.num_classes} classes)")
    print(f"Fold: {config.fold_id}")
    print(f"Epochs: {config.epochs}")
    print(f"Batch size: {config.batch_size}")
    print(f"Loss weights: λ_SEG={config.lambda_seg}, λ_GAN={config.lambda_gan}")
    print(f"Focal gamma: {config.focal_gamma}")
    print(f"Attention weight: {config.attention_weight}")
    print("="*70 + "\n")

    # Fail fast: only the binary scenario is supported, so reject anything
    # else before spending minutes on data loading and model building.
    if config.num_classes != 2:
        raise FileNotFoundError(
            f"Unsupported num_classes={config.num_classes}; only 2 "
            "(Background, Specialized_GM) is supported"
        )

    # Check initial GPU memory
    get_gpu_memory_info()

    # Initialize data loader
    data_config = DataConfig()
    data_loader = P1DataLoader(data_config)

    # Load datasets
    print("Loading training data...")
    train_dataset = data_loader.create_dataset_for_fold(
        fold_id=config.fold_id,
        split='train',
        preprocessing=config.preprocessing,
        class_scenario=config.class_scenario,
        batch_size=config.batch_size,
        shuffle=True
    )

    print("Loading validation data...")
    val_dataset = data_loader.create_dataset_for_fold(
        fold_id=config.fold_id,
        split='val',
        preprocessing=config.preprocessing,
        class_scenario=config.class_scenario,
        batch_size=config.batch_size,
        shuffle=False
    )

    # Get dataset sizes
    # Note: from_generator pipelines do not report a usable cardinality, so
    # the batch count is obtained by iterating once; this pass also warms the
    # in-memory cache so epoch 1 is fast.
    print("Warming dataset cache (first pass over data — subsequent epochs use RAM)...")
    train_size = sum(1 for _ in train_dataset)
    val_size = sum(1 for _ in val_dataset)
    # Do NOT rebuild the datasets here — that would create new generators and
    # throw away the cache we just populated.

    print(f"Training samples (batches): {train_size}")
    print(f"Validation samples (batches): {val_size}\n")

    # Compute or load class weights
    print("Computing class weights from training data...")
    try:
        class_weights = load_class_weights(
            config.fold_id, config.class_scenario,
            config.preprocessing, config.weights_dir
        )
        print("✅ Loaded pre-computed class weights")
    except FileNotFoundError:
        print("Computing class weights (this may take a few minutes)...")
        results = compute_and_save_class_weights(
            config.fold_id, config.class_scenario,
            config.preprocessing, str(config.weights_dir)
        )
        class_weights = np.array(results['class_weights'], dtype=np.float32)

    print(f"Class weights: {class_weights}")

    # Build models
    print("\n🏗️ Building models...")
    generator = build_unet_3class(input_shape=(256, 256, 1), num_classes=config.num_classes)
    discriminator = build_attention_discriminator(
        config.num_classes,
        input_channels=1,
        attention_weight=config.attention_weight,
        use_groupnorm=True  # Consistent with generator
    )

    print(f"Generator parameters: {generator.count_params():,}")
    print(f"Discriminator parameters: {discriminator.count_params():,}\n")

    # Optimizers
    generator_optimizer = tf.keras.optimizers.legacy.Adam(
        config.learning_rate, beta_1=config.beta_1
    )
    discriminator_optimizer = tf.keras.optimizers.legacy.Adam(
        config.learning_rate, beta_1=config.beta_1
    )

    # Initialize optimizer variables
    # CRITICAL: Build optimizer variables by calling them once with dummy data
    # This prevents the "tf.function only supports singleton tf.Variables" error
    # NOTE(review): this applies one real optimizer step on a meaningless
    # loss; it is only undone when a checkpoint / best weights are loaded below.
    print("Initializing optimizer variables...")
    dummy_input = tf.zeros((1, 256, 256, 1))
    dummy_target = tf.zeros((1, 256, 256, config.num_classes))

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(dummy_input, training=True)
        disc_output = discriminator([dummy_input, dummy_target], training=True)
        # Dummy losses
        dummy_gen_loss = tf.reduce_mean(gen_output)
        dummy_disc_loss = tf.reduce_mean(disc_output)

    # Apply dummy gradients to build optimizer variables
    # Don't include class_weights since they're not trainable
    gen_grads = gen_tape.gradient(dummy_gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(dummy_disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    print("✅ Optimizer variables initialized\n")

    # Checkpoint
    checkpoint = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator
    )

    manager = tf.train.CheckpointManager(
        checkpoint, config.checkpoint_dir, max_to_keep=1
    )

    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"✅ Restored from checkpoint: {manager.latest_checkpoint}\n")
    else:
        print("Starting training from scratch\n")

    # Load pretrained models (best weights from a previous run, if present):
    generator_weights_path = f"{config.checkpoint_dir}/best_dice_generator.h5"
    if os.path.isfile(generator_weights_path):
        generator.load_weights(generator_weights_path)

    discriminator_weights_path = f"{config.checkpoint_dir}/best_dice_discriminator.h5"
    if os.path.isfile(discriminator_weights_path):
        discriminator.load_weights(discriminator_weights_path)

    # Get example for visualization (the batch after the first `skip_n`)
    skip_n = 1  # min(100 // config.batch_size, val_size - 1)
    example_paired, example_target, _, _ = next(iter(val_dataset.skip(skip_n).take(20)))

    print("Initializing metrics computer...")

    # Training history
    history = {
        'gen_total_loss': [],
        'gen_gan_loss': [],
        'gen_seg_loss': [],
        'gen_wce_loss': [],
        'gen_focal_loss': [],
        'disc_loss': [],
        'val_loss': [],
        'beta_value': [],
        'val_metrics': []
    }

    # Best composite validation score seen so far (see selection rule below)
    best_val_score = 0.0
    exclude_class = None  # No class is excluded from the validation loss
    # Metrics key of the foreground (GM) class, i.e. the last class index
    gm_dice_key = f"class_{config.num_classes - 1}"
    # Periodic tf.train.Checkpoint saving is switched off (it was disabled via
    # `... and False`); best weights are stored as .h5 files instead.
    save_periodic_checkpoints = False

    try:
        for epoch in range(config.epochs):
            start_time = time.time()

            # Compute beta for this epoch
            beta_value = compute_beta_schedule(
                epoch, config.epochs,
                config.beta_threshold, config.beta_smoothness
            )
            beta_float = float(beta_value.numpy())

            # Training metrics
            epoch_gen_total_loss = []
            epoch_gen_gan_loss = []
            epoch_gen_seg_loss = []
            epoch_gen_wce_loss = []
            epoch_gen_focal_loss = []
            epoch_disc_loss = []

            # Update learning rate based on epoch: take the smaller of
            #  - a decay linear in beta (down to 1/8 of the base rate once the
            #    focal loss is fully switched on), and
            #  - a decay linear in the epoch index (towards 0.5% of the base rate).
            new_lr_1 = config.learning_rate * (1 - (7/8) * beta_float)
            new_lr_2 = config.learning_rate * (1 - (1 - 0.5e-2) * (epoch / config.epochs))
            new_lr = min(new_lr_1, new_lr_2)
            generator_optimizer.learning_rate.assign(new_lr)
            discriminator_optimizer.learning_rate.assign(new_lr)

            # NOTE(review): this annealed value is only printed; train_step
            # below still receives the constant config.lambda_gan. Confirm
            # which of the two is intended.
            lambda_GAN = config.lambda_gan * (1 - beta_float)
            print(f"\nEpoch {epoch+1}/{config.epochs} (β={beta_float:.4f}) (λ_GAN={lambda_GAN:.4f}) (lr={new_lr:.6f})")
            train_bar = tqdm(train_dataset, total=train_size, desc="Training")

            for paired_input, target_mask, _patient_ids, _slice_nums in train_bar:

                # Prepare inputs: normalize FLAIR + one-hot encode target
                flair_normalized, target_onehot = prepare_inputs(
                    paired_input, target_mask, config.num_classes
                )

                # Train step
                gen_total, gen_gan, gen_seg, gen_wce, gen_focal, disc, _ = train_step(
                    flair_normalized, target_onehot,
                    generator, discriminator,
                    generator_optimizer, discriminator_optimizer,
                    class_weights, beta_value,
                    config.lambda_gan, config.lambda_seg,
                    config.focal_gamma, config.use_focal_alpha
                )

                # Pull each scalar to the host once and reuse it
                gen_total_np = gen_total.numpy()
                gen_seg_np = gen_seg.numpy()
                disc_np = disc.numpy()

                epoch_gen_total_loss.append(gen_total_np)
                epoch_gen_gan_loss.append(gen_gan.numpy())
                epoch_gen_seg_loss.append(gen_seg_np)
                epoch_gen_wce_loss.append(gen_wce.numpy())
                epoch_gen_focal_loss.append(gen_focal.numpy())
                epoch_disc_loss.append(disc_np)

                # Update progress bar
                train_bar.set_postfix({
                    'G_loss': f"{gen_total_np:.4f}",
                    'D_loss': f"{disc_np:.4f}",
                    'SEG': f"{gen_seg_np:.4f}"
                })

            # Calculate epoch averages
            avg_gen_total = np.mean(epoch_gen_total_loss)
            avg_gen_gan = np.mean(epoch_gen_gan_loss)
            avg_gen_seg = np.mean(epoch_gen_seg_loss)
            avg_gen_wce = np.mean(epoch_gen_wce_loss)
            avg_gen_focal = np.mean(epoch_gen_focal_loss)
            avg_disc = np.mean(epoch_disc_loss)

            history['gen_total_loss'].append(avg_gen_total)
            history['gen_gan_loss'].append(avg_gen_gan)
            history['gen_seg_loss'].append(avg_gen_seg)
            history['gen_wce_loss'].append(avg_gen_wce)
            history['gen_focal_loss'].append(avg_gen_focal)
            history['disc_loss'].append(avg_disc)
            history['beta_value'].append(beta_float)

            # Validation
            val_losses = []
            all_val_true = []
            all_val_pred = []

            for val_paired, val_target, _patient_ids, _slice_nums in val_dataset:
                try:
                    val_flair_norm, val_target_onehot = prepare_inputs(
                        val_paired, val_target, config.num_classes
                    )

                    val_pred = generator(val_flair_norm, training=False)

                    # NOTE(review): uses the default use_focal_alpha=True
                    # rather than config.use_focal_alpha — confirm intended.
                    val_seg_loss, _, _ = adaptive_segmentation_loss(
                        val_target_onehot, val_pred, class_weights,
                        beta_value, focal_gamma=config.focal_gamma, exclude_class=exclude_class
                    )
                except Exception as exc:
                    # Best effort: a failing batch is skipped, but it is
                    # reported, and KeyboardInterrupt/SystemExit still
                    # propagate (a bare `except:` swallowed those too).
                    print(f"Warning: skipping validation batch ({type(exc).__name__}: {exc})")
                    continue

                # Store true and prediction values for final metrics calculation
                all_val_true.append(val_target_onehot)
                all_val_pred.append(val_pred)

                if not tf.math.is_nan(val_seg_loss):
                    val_losses.append(val_seg_loss.numpy())

            if val_losses:
                avg_val_loss = np.mean(val_losses)
                history['val_loss'].append(avg_val_loss)

                # Compute class-wise metrics
                val_metrics = compute_classwise_metrics(
                    all_val_true, all_val_pred,
                    config.num_classes
                )
                history['val_metrics'].append(val_metrics)

                # Print validation results
                epoch_time = time.time() - start_time
                print(f"\n{'='*70}")
                print(f"Epoch {epoch+1}/{config.epochs} Summary (Time: {epoch_time:.2f}s)")
                print(f"{'='*70}")
                print(f"Training Losses:")
                print(f" Generator Total: {avg_gen_total:.4f} | GAN: {avg_gen_gan:.4f} | SEG: {avg_gen_seg:.4f}")
                print(f" WCE: {avg_gen_wce:.4f} | Focal: {avg_gen_focal:.4f} | Discriminator: {avg_disc:.4f}")
                print(f"\nValidation Loss: {avg_val_loss:.4f}")
                print(f"\nClass-wise Dice Scores:")
                # Reset per epoch so a missing key cannot reuse a stale value
                gm_val_dice = None
                for class_name, dice_val in val_metrics['dice'].items():
                    if class_name != 'mean':
                        print(f" {class_name}: {dice_val:.4f}")
                    if class_name == gm_dice_key:
                        gm_val_dice = dice_val
                print(f" Mean Dice: {val_metrics['dice']['mean']:.4f}")
                print(f"\nClass-wise Precision:")
                for class_name, prec_val in val_metrics['precision'].items():
                    if class_name != 'mean':
                        print(f" {class_name}: {prec_val:.4f}")
                print(f" Mean Precision: {val_metrics['precision']['mean']:.4f}")
                print(f"\nClass-wise Recall:")
                for class_name, rec_val in val_metrics['recall'].items():
                    if class_name != 'mean':
                        print(f" {class_name}: {rec_val:.4f}")
                print(f" Mean Recall: {val_metrics['recall']['mean']:.4f}")
                print(f"{'='*70}\n")

                # Save best model based on a composite score: 90% GM Dice,
                # 10% (1 - 10 * validation loss); only once beta > 0.9, i.e.
                # when the focal phase is (almost) fully active.
                if gm_val_dice is None:
                    print(f"Warning: '{gm_dice_key}' missing from Dice metrics; skipping best-model selection")
                else:
                    overall_val_performance = 0.9 * gm_val_dice + 0.1 * (1 - 10 * avg_val_loss)
                    if overall_val_performance > best_val_score and beta_float > 0.9:
                        best_val_score = overall_val_performance
                        generator.save_weights(f"{config.checkpoint_dir}/best_dice_generator.h5")
                        discriminator.save_weights(f"{config.checkpoint_dir}/best_dice_discriminator.h5")
                        print(f"✓ Best model saved (performance: {best_val_score:.4f})")
            else:
                print("Warning: No valid validation batches")
                # Define the value so the epoch summary below cannot hit an
                # unbound (first epoch) or stale (later epochs) variable.
                avg_val_loss = float('nan')
                history['val_loss'].append(avg_val_loss)
                history['val_metrics'].append({})

            # Print epoch summary
            epoch_time = time.time() - start_time
            print(f"Epoch {epoch+1} Summary:")
            print(f" Gen Total Loss: {avg_gen_total:.4f}")
            print(f" Gen GAN Loss: {avg_gen_gan:.4f}")
            print(f" Gen Seg Loss: {avg_gen_seg:.4f}")
            print(f" - WCE component: {avg_gen_wce:.4f}")
            print(f" - Focal component: {avg_gen_focal:.4f}")
            print(f" Disc Loss: {avg_disc:.4f}")
            print(f" Val Loss: {avg_val_loss:.4f}")
            print(f" Beta: {beta_float:.4f}")
            print(f" Time: {epoch_time:.2f}s")

            # Save checkpoint (currently disabled, see flag above)
            if save_periodic_checkpoints and (epoch + 1) % 5 == 0:
                manager.save()
                print(f" 💾 Saved checkpoint")

            # Generate sample images after every epoch
            generate_and_save_images(
                generator, example_paired, example_target,
                epoch + 1, config.figures_dir, config.num_classes
            )
            print(f" 📊 Saved visualization")

        # Save history: convert NumPy numbers (also inside the per-epoch
        # val_metrics dicts) so json.dump cannot fail after a full training run.
        history_serializable = {
            key: _history_to_jsonable(values)
            for key, values in history.items()
        }

        history_file = config.checkpoint_dir / "history.json"
        with open(history_file, 'w') as f:
            json.dump(history_serializable, f, indent=2)

        return history, history_file

    finally:
        # CRITICAL: Always cleanup, even if training fails
        # This runs whether training succeeds or fails
        print("\n🧹 Cleaning up resources...")

        # Delete models explicitly to break references
        try:
            del generator
            del discriminator
            del generator_optimizer
            del discriminator_optimizer
            del checkpoint
            del manager
            del train_dataset
            del val_dataset
            # class_weights don't need deletion (they're constants, not variables)
            print("✅ Deleted model objects")
        except Exception as e:
            print(f"⚠️ Error deleting objects: {e}")

        # Clear GPU memory
        clear_gpu_memory()

        # Check final GPU memory
        get_gpu_memory_info()
1296
+
1297
+
1298
+ ###################### Main Execution ######################
1299
+
1300
if __name__ == "__main__":
    # Example run: binary class scenario, standard preprocessing, fold 0
    experiment_settings = {
        'variant': 1,
        'preprocessing': 'standard',
        'class_scenario': 'binary',
        'fold_id': 0,
    }
    config = ExperimentConfig(**experiment_settings)

    history, history_path = train_experiment_with_metrics(config)

    banner = "=" * 70
    print("\n" + banner)
    print("EXPERIMENT COMPLETE")
    print(banner)
models/for_GM/model_training_scripts/p1_predict_new_data_gm.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P1 Article - Prediction Script for New Data (No Ground Truth)
3
+
4
+ Predicts specialized Gray Matter segmentation masks for new HC/MS cohort patients.
5
+
6
+ Outputs per patient:
7
+ - {patient_id}_gm_mask.nii.gz → binary gm mask (class 1)
8
+
9
+ Developer:
10
+ Mahdi Bashiri Bawil
11
+ """
12
+
13
+ import tensorflow as tf
14
+ import os
15
+ import numpy as np
16
+ from pathlib import Path
17
+ from tqdm import tqdm
18
+ import nibabel as nib
19
+ import argparse
20
+
21
print("TensorFlow Version:", tf.__version__)


###################### GPU Configuration ######################

# Enable on-demand GPU memory allocation on every visible GPU; fall back to
# CPU inference (with a notice) when none is found.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected – inference will run on CPU")
else:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"✅ GPU memory growth enabled ({len(physical_devices)} GPU(s) found)")
    except RuntimeError as e:
        # Reported but not fatal: the script continues with default settings.
        print(f"GPU configuration error: {e}")
36
+
37
+
38
+ ###################### Configuration ######################
39
+
40
class PredictConfig:
    """
    All settings for the new-data prediction pipeline.
    Edit the values in __init__ or pass overrides via the CLI at the bottom.

    Constructing an instance prints a summary of the configuration.
    """

    def __init__(
        self,
        # ── Model settings ──────────────────────────────────────────────────
        variant: int = 1,
        preprocessing: str = "standard",
        class_scenario: str = "binary",
        architecture_name: str = "unet",
        # NOTE(review): the training script in this commit saves generator
        # weights as "best_dice_generator.h5"; confirm this default matches
        # the file that is actually deployed.
        model_name: str = "best_dice_model.h5",
        fold_id: int = 0,

        # ── Slice range (1-based, inclusive) ────────────────────────────────
        # Only slices within [slice_start, slice_end] are fed to the model.
        # All other slices receive empty (zero) masks.
        slice_start: int = 1,
        slice_end: int = 20,

        # ── Data root ───────────────────────────────────────────────────────
        data_root: str = "/mnt/d/TEMP_P4",

        # ── Post-processing ─────────────────────────────────────────────────
        apply_postprocess: bool = False,
        min_object_size: int = 5,
        closing_kernel_size: int = 2,
    ):
        """
        Store the settings, derive the cohort/model paths and print a summary.

        Raises:
            ValueError: if the slice range is not a valid 1-based inclusive
                range (``slice_start < 1`` or ``slice_end < slice_start``).
        """
        # An inverted or non-positive range would select no slices at all and
        # silently yield all-empty masks, so reject it up front.
        if slice_start < 1 or slice_end < slice_start:
            raise ValueError(
                f"Invalid slice range [{slice_start}, {slice_end}]: expected "
                "1 <= slice_start <= slice_end (1-based, inclusive)"
            )

        # Experiment
        self.variant = variant
        self.fold_id = fold_id
        self.preprocessing = preprocessing
        self.class_scenario = class_scenario
        self.architecture_name = architecture_name
        self.model_name = model_name

        # Classes
        self.num_classes = 2
        self.class_names = ["Background", "Specialized_GM"]

        # Image dimensions (must match training)
        self.img_width = 256
        self.img_height = 256

        # Slice range (1-based, inclusive)
        self.slice_start = slice_start
        self.slice_end = slice_end

        # Post-processing (reported in the summary below)
        self.apply_postprocess = apply_postprocess
        self.min_object_size = min_object_size
        self.closing_kernel_size = closing_kernel_size

        # Data root
        self.data_root = Path(data_root)

        # Cohort sub-directories (relative to data_root)
        self.cohorts = {
            "HC": self.data_root / "HC_COHORT_PREP_prepared" / "FLAIR_Preprocessed",
            "MS": self.data_root / "MS_COHORT_PREP_prepared" / "FLAIR_Preprocessed",
        }

        # Model path (relative to the current working directory)
        self.results_dir = Path(
            f"results_fold_{fold_id}_var_{variant}_bet_zscore_gm"  # adjust if you use a single fold
        )
        self.models_dir = self.results_dir / "models" / f"{preprocessing}_{class_scenario}"

        # ── Print summary ────────────────────────────────────────────────────
        print(f"\n{'='*70}")
        print("PREDICTION CONFIGURATION (New Data)")
        print(f"{'='*70}")
        print(f" Variant : {self.variant}")
        print(f" Fold : {self.fold_id}")
        print(f" Preprocessing : {self.preprocessing}")
        print(f" Class scenario : {self.class_scenario} ({self.num_classes} classes)")
        print(f" Architecture : {self.architecture_name}")
        print(f" Model file : {self.model_name}")
        print(f" Slice range : {self.slice_start} – {self.slice_end} (1-based)")
        print(f" Post-processing : {self.apply_postprocess}")
        print(f" Data root : {self.data_root}")
        print(f"{'='*70}\n")
125
+
126
+
127
+ ###################### Utility Helpers ######################
128
+
129
def load_nifti(path: Path):
    """
    Read a NIfTI volume from disk.

    Args:
        path: location of the NIfTI file.

    Returns:
        A ``(data, image)`` tuple: the voxel data as a float32 numpy array
        and the nibabel image object it was read from.
    """
    nifti_image = nib.load(str(path))
    voxel_data = nifti_image.get_fdata(dtype=np.float32)
    return voxel_data, nifti_image
133
+
134
+
135
def save_binary_nifti(mask: np.ndarray, save_path: Path, reference_img):
    """
    Save a binary 3-D mask as a uint8 NIfTI file.

    Args:
        mask         : boolean/uint8 array. It is written as given, so its
                       shape must match the grid described by
                       ``reference_img`` for the reused affine to be valid.
        save_path    : destination path (*.nii.gz); missing parent
                       directories are created.
        reference_img: nibabel image whose affine/header are reused.
    """
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # nibabel takes the on-disk data type from a supplied header, not from the
    # array. Reusing the reference header unchanged would therefore store the
    # mask with the reference image's dtype. Work on a copy (the caller's image
    # stays untouched) and declare uint8 explicitly.
    header = reference_img.header.copy()
    header.set_data_dtype(np.uint8)

    nifti_out = nib.Nifti1Image(
        mask.astype(np.uint8),
        reference_img.affine,
        header,
    )
    nib.save(nifti_out, str(save_path))
151
+
152
+
153
def preprocess_slice(slice_2d: np.ndarray, target_h: int = 256, target_w: int = 256) -> np.ndarray:
    """
    Turn one 2-D slice into a model-ready batch of shape (1, H, W, 1).

    The slice is resized to (target_h, target_w) with bilinear
    interpolation only when its shape differs from the target. Intensities
    are passed through unchanged: the input volumes are assumed to be
    normalised upstream, so no intensity normalisation is applied here.

    Args:
        slice_2d: 2-D array (rows, cols).
        target_h: required number of rows.
        target_w: required number of columns.

    Returns:
        float32 array of shape (1, target_h, target_w, 1).

    Raises:
        ValueError: if ``slice_2d`` is not 2-D (shape unpacking fails).
        ImportError: if a resize is required and OpenCV is not installed.
    """
    h, w = slice_2d.shape
    if h != target_h or w != target_w:
        # OpenCV is needed only when a resize is actually required, so it is
        # imported lazily. There is no fallback if it is missing.
        import cv2

        # cv2.resize expects the destination size as (width, height).
        slice_2d = cv2.resize(
            slice_2d, (target_w, target_h), interpolation=cv2.INTER_LINEAR
        )

    # shape → (1, H, W, 1)
    return slice_2d[np.newaxis, :, :, np.newaxis].astype(np.float32)
172
+
173
+
174
def post_process_pred(pred_classes: np.ndarray, num_classes: int = 2,
                      min_object_size: int = 5, closing_kernel_size: int = 2) -> np.ndarray:
    """
    Clean up the foreground (label 1) of a single 2-D label map.

    Steps:
        1. Take the binary mask of pixels labelled 1.
        2. If the mask is non-empty, apply binary closing with a disk
           footprint of radius ``closing_kernel_size``, then drop connected
           components smaller than ``min_object_size``.
        3. Write the surviving pixels as label 1 into a zero-filled map
           with the same shape and dtype as the input.

    Args:
        pred_classes       : 2-D integer label map.
        num_classes        : accepted for signature compatibility; not used.
        min_object_size    : ``min_size`` passed to ``remove_small_objects``.
        closing_kernel_size: radius of the closing footprint.

    Returns:
        Label map containing only 0 and 1.
    """
    from skimage.morphology import remove_small_objects, binary_closing, disk

    footprint = disk(closing_kernel_size)

    foreground = (pred_classes == 1)
    if foreground.any():
        # Closing first, then speck removal; an empty mask is left as is.
        foreground = remove_small_objects(
            binary_closing(foreground, footprint), min_size=min_object_size
        )

    cleaned = np.zeros_like(pred_classes)
    cleaned[foreground] = 1
    return cleaned
205
+
206
+
207
+ ###################### Model Loading ######################
208
+
209
def load_model(config: PredictConfig, fold_id: int):
    """
    Instantiate the configured architecture and load one fold's weights.

    The builder module is imported lazily, so only the selected
    architecture has to be importable. Weights are read from
    ``config.models_dir / f"fold_{fold_id}" / config.model_name``.

    Args:
        config : prediction configuration (architecture, image size,
                 number of classes, model directory and file name).
        fold_id: fold whose weights should be loaded.

    Returns:
        The model returned by the builder, with weights loaded.

    Raises:
        ValueError: if ``config.architecture_name`` is not recognised.
        FileNotFoundError: if the weights file does not exist.
    """
    arch = config.architecture_name
    if arch == "unet":
        from unet_model import build_unet_3class as builder
    elif arch == "attnunet":
        from attn_unet_model import build_attention_unet_3class as builder
    elif arch == "dlv3unet":
        from dlv3_unet_model_GN import build_deeplabv3_unet_3class as builder
    elif arch == "transunet":
        from trans_unet_model import build_trans_unet_3class as builder
    else:
        raise ValueError(f"Unknown architecture: {config.architecture_name}")

    weights_file = config.models_dir / f"fold_{fold_id}" / config.model_name
    if not weights_file.exists():
        raise FileNotFoundError(f"Model not found: {weights_file}")

    network = builder(
        input_shape=(config.img_height, config.img_width, 1),
        num_classes=config.num_classes,
    )
    network.load_weights(str(weights_file))
    print(f" ✅ Fold {fold_id} model loaded from: {weights_file}")
    return network
242
+
243
+
244
+ ###################### Per-Patient Prediction ######################
245
+
246
def predict_patient(
    patient_id: str,
    flair_path: Path,
    brain_mask_path: Path,
    models: list,  # list of keras generators (one per fold)
    config: PredictConfig,
    gm_out_dir: Path,
):
    """
    Run inference for a single patient and save the Specialized_GM mask.

    Steps:
        1. Load the FLAIR volume and the brain mask.
        2. Set every voxel outside the brain mask to the volume minimum.
        3. For each slice in [slice_start, slice_end] (1-based, inclusive,
           clamped to the volume depth):
            a. Resize to the model input size if necessary.
            b. Run all fold models and average their outputs.
            c. argmax → class label.
            d. Optional post-processing.
            e. Resize the label map back to the original in-plane size.
        4. Slices outside the range keep an empty (zero) mask.
        5. Save the binary Specialized_GM mask (class 1) to
           ``gm_out_dir / f"{patient_id}_gm_mask.nii.gz"``.

    Args:
        patient_id     : identifier used in the output file name and logs.
        flair_path     : path of the FLAIR NIfTI volume.
        brain_mask_path: path of the matching brain-mask NIfTI volume.
        models         : non-empty list of callables taking
                         ``(batch, training=False)``; their outputs are averaged.
        config         : prediction configuration.
        gm_out_dir     : directory that receives the mask.

    Raises:
        ValueError: if ``models`` is empty or the brain mask shape differs
                    from the FLAIR shape.
    """
    # Averaging over zero models would divide by zero and silently yield
    # all-background masks, so reject it up front.
    if not models:
        raise ValueError("predict_patient requires at least one model")

    # ── Load data ────────────────────────────────────────────────────────────
    flair_data, flair_img = load_nifti(flair_path)
    brain_mask, _ = load_nifti(brain_mask_path)

    if brain_mask.shape != flair_data.shape:
        raise ValueError(
            f"Brain mask shape {brain_mask.shape} does not match "
            f"FLAIR shape {flair_data.shape} for patient {patient_id}"
        )

    # Brain extraction: non-brain voxels are set to the volume minimum
    # (not multiplied by the mask).
    brain_mask_bool = brain_mask > 0
    flair_brain = np.copy(flair_data)
    flair_brain[~brain_mask_bool] = np.min(flair_data)

    num_slices = flair_brain.shape[2]

    # slice_start / slice_end are 1-based and inclusive; convert to 0-based
    # indices and clamp to the actual volume depth.
    active_start = max(0, config.slice_start - 1)
    active_end = min(num_slices - 1, config.slice_end - 1)

    # Output volume has the same spatial shape as the input. Slices that are
    # never visited stay zero.
    H, W = flair_brain.shape[0], flair_brain.shape[1]
    gm_volume = np.zeros((H, W, num_slices), dtype=np.uint8)

    # ── Inference loop (active range only; empty if the range is inverted) ──
    for s in range(active_start, active_end + 1):
        slice_2d = flair_brain[:, :, s]
        model_input = preprocess_slice(
            slice_2d, config.img_height, config.img_width
        )

        # Ensemble: average the per-class outputs of all fold models
        softmax_sum = np.zeros(
            (1, config.img_height, config.img_width, config.num_classes),
            dtype=np.float32,
        )
        for gen in models:
            softmax_sum += gen(model_input, training=False).numpy()

        softmax_avg = softmax_sum / len(models)
        pred_slice = np.argmax(softmax_avg, axis=-1)[0]

        # Optional post-processing (done at model resolution)
        if config.apply_postprocess:
            pred_slice = post_process_pred(
                pred_slice,
                num_classes=config.num_classes,
                min_object_size=config.min_object_size,
                closing_kernel_size=config.closing_kernel_size,
            )

        # Bring the label map back to the original in-plane size. Nearest
        # neighbour keeps the labels integral.
        if pred_slice.shape != (H, W):
            import cv2
            pred_slice = cv2.resize(
                pred_slice.astype(np.float32), (W, H),
                interpolation=cv2.INTER_NEAREST,
            ).astype(np.uint8)

        # Specialized_GM = class 1 in the 2-class setting
        gm_volume[:, :, s] = (pred_slice == 1).astype(np.uint8)

    # ── Save outputs ─────────────────────────────────────────────────────────
    gm_path = gm_out_dir / f"{patient_id}_gm_mask.nii.gz"

    save_binary_nifti(gm_volume, gm_path, flair_img)

    n_gm = int(gm_volume.sum())
    print(
        f" Patient {patient_id}: "
        f"GM voxels = {n_gm:6d}"
    )
    print(f" → {gm_path}")
352
+
353
+
354
+ ###################### Main Prediction Pipeline ######################
355
+
356
def run_prediction(config: PredictConfig, fold_ids: list = None):
    """
    Full prediction pipeline for all patients in the configured cohorts.

    For every cohort directory in ``config.cohorts`` the FLAIR volumes are
    read from ``files/``, brain masks from ``Brain_Masks/`` and the
    predicted masks are written to ``GM_Masks/``. Patients without a brain
    mask, or whose prediction raises, are counted as skipped and the run
    continues.

    Args:
        config   : PredictConfig object.
        fold_ids : List of fold IDs to ensemble (e.g. [0, 1, 2, 3]).
                   If None, defaults to [0].

    Raises:
        ValueError: if ``fold_ids`` is an empty list.
    """
    if fold_ids is None:
        fold_ids = [0]
    if not fold_ids:
        # With no models the ensemble average is undefined.
        raise ValueError("fold_ids must contain at least one fold ID")

    # ── Load all fold models ─────────────────────────────────────────────────
    print(f"\nLoading models for folds: {fold_ids}")
    models = [load_model(config, fold_id) for fold_id in fold_ids]
    print(f"✅ {len(models)} model(s) loaded\n")

    flair_suffix = ".nii.gz"

    # ── Iterate over cohorts ─────────────────────────────────────────────────
    for cohort_name, cohort_flair_dir in config.cohorts.items():
        files_dir = cohort_flair_dir / "files"
        brain_masks_dir = cohort_flair_dir / "Brain_Masks"
        gm_out_dir = cohort_flair_dir / "GM_Masks"

        # Discover patients from the files directory
        flair_files = sorted(files_dir.glob("*.nii.gz"))
        if not flair_files:
            print(f"⚠️ No FLAIR files found in {files_dir} – skipping {cohort_name} cohort")
            continue

        # Create the output directory only once there is something to write,
        # so a missing cohort does not leave an empty directory tree behind.
        gm_out_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n{'='*70}")
        print(f"COHORT: {cohort_name} ({len(flair_files)} patients found)")
        print(f" FLAIR dir : {files_dir}")
        print(f" Brain masks dir : {brain_masks_dir}")
        print(f" Output GM dir : {gm_out_dir}")
        print(f"{'='*70}")

        skipped = 0
        for flair_path in tqdm(flair_files, desc=f"{cohort_name} patients"):
            # Patient ID = file name without the trailing ".nii.gz". Only the
            # suffix is removed, so IDs that contain ".nii" stay intact.
            patient_id = flair_path.name[: -len(flair_suffix)]

            brain_mask_path = brain_masks_dir / f"{patient_id}_brain_mask.nii.gz"

            if not brain_mask_path.exists():
                print(
                    f"\n ⚠️ Brain mask not found for patient {patient_id} "
                    f"(expected: {brain_mask_path}) – skipping"
                )
                skipped += 1
                continue

            try:
                predict_patient(
                    patient_id=patient_id,
                    flair_path=flair_path,
                    brain_mask_path=brain_mask_path,
                    models=models,
                    config=config,
                    gm_out_dir=gm_out_dir,
                )
            except Exception as exc:
                # Deliberate best-effort at the batch boundary: one broken
                # patient must not abort the whole cohort.
                print(f"\n ❌ Error processing patient {patient_id}: {exc}")
                skipped += 1

        done = len(flair_files) - skipped
        print(
            f"\n ✅ {cohort_name} cohort done: {done} predicted, {skipped} skipped\n"
        )

    print("\n" + "="*70)
    print("ALL COHORTS PROCESSED")
    print("="*70)
434
+
435
+
436
+ ###################### Entry Point ######################
437
+
438
if __name__ == "__main__":

    # Command-line entry point: parse options, build the configuration and
    # run the prediction pipeline for the requested folds.
    parser = argparse.ArgumentParser(
        description="P4 – Predict Specialized_GM for new HC / MS cohort data"
    )
    parser.add_argument("--variant", type=int, default=1)
    parser.add_argument("--preprocessing", type=str, default="standard")
    parser.add_argument("--class_scenario", type=str, default="binary",
                        choices=["binary"])
    parser.add_argument("--architecture", type=str, default="unet",
                        choices=["unet", "attnunet", "dlv3unet", "transunet"])
    parser.add_argument("--model_name", type=str, default="best_dice_generator.h5")
    parser.add_argument("--folds", type=int, nargs="+", default=[0],
                        help="Fold IDs to ensemble (e.g. --folds 0 1 2 3)")
    parser.add_argument("--slice_start", type=int, default=1,
                        help="First slice to predict (1-based, inclusive)")
    parser.add_argument("--slice_end", type=int, default=20,
                        help="Last slice to predict (1-based, inclusive)")
    parser.add_argument("--data_root", type=str, default="/mnt/d/TEMP_P4")
    # store_true: the value is False unless the flag is given. Combined with
    # the negation below, post-processing is applied by default and
    # `--no_postprocess` switches it off. The previous store_false action
    # inverted this: the flag enabled post-processing and the default
    # disabled it.
    parser.add_argument("--no_postprocess", action="store_true",
                        help="Disable morphological post-processing")
    parser.add_argument("--min_object_size", type=int, default=5)
    parser.add_argument("--closing_size", type=int, default=2)
    args = parser.parse_args()

    config = PredictConfig(
        variant=args.variant,
        preprocessing=args.preprocessing,
        class_scenario=args.class_scenario,
        architecture_name=args.architecture,
        model_name=args.model_name,
        slice_start=args.slice_start,
        slice_end=args.slice_end,
        data_root=args.data_root,
        apply_postprocess=not args.no_postprocess,
        min_object_size=args.min_object_size,
        closing_kernel_size=args.closing_size,
    )

    run_prediction(config, fold_ids=args.folds)
models/for_GM/model_training_scripts/unet_model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###################### Libraries ######################
2
+ # Deep Learning
3
+ import keras
4
+ from keras.models import Model
5
+ from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
6
+
7
+
8
def build_unet_3class(input_shape=(256, 256, 1), num_classes=3):
    """
    Build a 2-D U-Net with dropout (no batch-normalisation layers).

    Layout: four encoder stages (64, 128, 256, 512 filters), a 1024-filter
    bottleneck and four decoder stages that mirror the encoder through
    skip connections. Each stage uses two 3x3 ReLU convolutions with
    'same' padding.

    Args:
        input_shape: shape of one input image, (height, width, channels).
        num_classes: number of output channels. 1 gives a single sigmoid
                     channel; any other value gives a softmax over
                     ``num_classes`` channels.

    Returns:
        A Keras ``Model`` named 'UNet'.
    """
    def conv_pair(tensor, filters):
        # Two consecutive 3x3 ReLU convolutions with the same filter count.
        tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return Conv2D(filters, 3, activation='relu', padding='same')(tensor)

    inputs = Input(input_shape)

    # Encoder: conv pair → max-pool → dropout; the conv output is kept for
    # the matching decoder stage.
    skips = []
    x = inputs
    for filters, drop_rate in ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2)):
        features = conv_pair(x, filters)
        skips.append(features)
        x = MaxPooling2D()(features)
        x = keras.layers.Dropout(drop_rate)(x)

    # Bottleneck
    x = conv_pair(x, 1024)
    x = keras.layers.Dropout(0.3)(x)

    # Decoder: upsample → concatenate with the skip → dropout → conv pair,
    # walking the stored skips from deepest to shallowest.
    decoder_stages = ((512, 0.2), (256, 0.2), (128, 0.1), (64, 0.1))
    for (filters, drop_rate), skip in zip(decoder_stages, reversed(skips)):
        x = Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
        x = concatenate([x, skip])
        x = keras.layers.Dropout(drop_rate)(x)
        x = conv_pair(x, filters)

    # Output layer
    if num_classes == 1:
        outputs = Conv2D(1, 1, activation='sigmoid')(x)
    else:
        outputs = Conv2D(num_classes, 1, activation='softmax')(x)

    return Model(inputs, outputs, name='UNet')
models/for_GM/model_training_scripts/utility_functions.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P1 Article - Utility Functions
3
+
4
+
5
+ Developer:
6
+ "Mahdi Bashiri Bawil"
7
+ """
8
+
9
+ import gc
10
+ import tensorflow as tf
11
+ from tensorflow.keras import backend as K
12
+
13
+
14
print("TensorFlow Version:", tf.__version__)

###################### GPU Configuration ######################

# Runs at import time: request on-demand GPU memory allocation for every
# visible GPU, or warn when none is present.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("⚠️ No GPU detected - training will be slow")
else:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        print("✅ GPU memory growth enabled")
        print(f" Available GPUs: {len(physical_devices)}")
    except RuntimeError as e:
        # Reported rather than raised so that importing this module never fails.
        print(f"GPU configuration error: {e}")
30
+
31
+ """
32
+ GPU Memory Management for Sequential Experiments
33
+ To properly release memory between experiments
34
+ """
35
+
36
+
37
def clear_gpu_memory():
    """
    Best-effort GPU/host memory cleanup between experiments.

    Call this after each experiment completes. Every step is optional:
    failures are reported on stdout and never raised.

    Steps:
        1. Clear the Keras backend session.
        2. Run the Python garbage collector three times.
        3. Reset the TF1-compat default graph.
        4. Reset TensorFlow's tracked memory statistics for 'GPU:0'.
        5. Toggle memory growth off and on for every visible GPU.
    """
    print("\n" + "="*70)
    print("CLEANING UP GPU MEMORY")
    print("="*70)

    # Clear Keras session
    K.clear_session()
    print("✅ Cleared Keras session")

    # Force garbage collection multiple times
    for _ in range(3):
        gc.collect()
    print("✅ Ran garbage collection (3 passes)")

    # Reset TensorFlow graphs
    tf.compat.v1.reset_default_graph()
    print("✅ Reset default graph")

    # Reset the tracked memory statistics of the first GPU. This affects the
    # reported statistics only. Catch Exception rather than using a bare
    # except, so KeyboardInterrupt/SystemExit still propagate (e.g. on a
    # machine without 'GPU:0').
    try:
        tf.config.experimental.reset_memory_stats('GPU:0')
        print("✅ Reset GPU memory stats")
    except Exception as e:
        print(f"⚠️ Could not reset GPU memory stats: {e}")

    # Toggle memory growth off and on for every GPU.
    # NOTE(review): the original intent was to flush the allocator, but
    # TensorFlow documents that memory growth cannot be changed once the
    # device is initialised. In that case this raises and is reported below.
    # Confirm whether this step has any effect in practice.
    try:
        physical_devices = tf.config.list_physical_devices('GPU')
        if physical_devices:
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, False)
                tf.config.experimental.set_memory_growth(device, True)
            print("✅ Reset memory growth (flushed allocator)")
    except Exception as e:
        print(f"⚠️ Could not reset memory growth: {e}")

    print("="*70 + "\n")
81
+
82
+
83
def get_gpu_memory_info():
    """
    Print current and peak memory usage (in MB) for every visible GPU.

    Intended for spotting memory leaks between experiments. Any failure is
    printed instead of raised.
    """
    try:
        for device in tf.config.list_physical_devices('GPU'):
            # Strip the '/physical_device:' prefix from the device name
            # before passing it to get_memory_info.
            device_label = device.name.replace('/physical_device:', '')
            details = tf.config.experimental.get_memory_info(device_label)
            current_mb = details['current'] / 1024**2
            peak_mb = details['peak'] / 1024**2
            print(f"GPU Memory - Current: {current_mb:.1f} MB, Peak: {peak_mb:.1f} MB")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_1.png ADDED

Git LFS Details

  • SHA256: eadc42599cba73222f8cf13373545bfa9c96b5e0d1d6e3b99c285a8d725d5439
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_001_2.png ADDED

Git LFS Details

  • SHA256: 542b3f350aacc50c8ab191629db82409dedc8c7738944f308eb385b3f7b1b11d
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_1.png ADDED

Git LFS Details

  • SHA256: a9d97fa90722f2f4d1090745e7728f4397eb01b68128d45f433169a49b1f9f6d
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_002_2.png ADDED

Git LFS Details

  • SHA256: f63730bc2934051ed6c59603ddcbecb26b1fea755f61ac070af206398b94f329
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_1.png ADDED

Git LFS Details

  • SHA256: 529ec0da4f46d09a2c998a368c46a588a8ede2e4ee666a2cccef5267f2af7456
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_003_2.png ADDED

Git LFS Details

  • SHA256: 0792796c6a778f512dd7404330cc089d4f360c73fef71169b4dd91ff90c76e76
  • Pointer size: 131 Bytes
  • Size of remote file: 175 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_1.png ADDED

Git LFS Details

  • SHA256: dc591c692a290a25a26bcce181ebad4af374b6e9dd214daa90b282f70ca25556
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_004_2.png ADDED

Git LFS Details

  • SHA256: 3ed42b8ad14b395da091c771a89e12fd1a85a1e5bf437fcb35f76b347fd14471
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_1.png ADDED

Git LFS Details

  • SHA256: 74ef1abee3e2eb492d4770740d0c9306cbd1fb36d98493840846c7d1f525a865
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_005_2.png ADDED

Git LFS Details

  • SHA256: fd47c91f71c11787f5349dc3a339484a789270f8a8069ac32e135c85890e0e4f
  • Pointer size: 131 Bytes
  • Size of remote file: 175 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_1.png ADDED

Git LFS Details

  • SHA256: 8b73521c6078398fe1b9f7dd5390a959202ed2e81aaa7dc53470629a559a18a6
  • Pointer size: 131 Bytes
  • Size of remote file: 187 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_006_2.png ADDED

Git LFS Details

  • SHA256: 508c67867c4f87b53f9f598319e336d85aa1e0ef9bea564840c508ba8800b5bd
  • Pointer size: 131 Bytes
  • Size of remote file: 185 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_1.png ADDED

Git LFS Details

  • SHA256: 75ffd97185ffacf187911c6c00b41f93dab4f0d166a8acac207e7a3d04a50755
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_007_2.png ADDED

Git LFS Details

  • SHA256: 7e670dca56771127c013c0fd3973547f6109618f0328434352e3e294ed26d590
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_1.png ADDED

Git LFS Details

  • SHA256: 0c825e24aa77aea7836285836c7d99a98e24e61abb6a110549741ca53ede5272
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_008_2.png ADDED

Git LFS Details

  • SHA256: 573c61fbfa0a3302097b9df6e5525f1418dbbd16062922fb7424184b6c501675
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_1.png ADDED

Git LFS Details

  • SHA256: 148837a15ed963f1731557f51728701b2d4e4892d0de145bb6bb50f5e93dc2ba
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_009_2.png ADDED

Git LFS Details

  • SHA256: fc2a65766cb895bea7edd8e893268ff6a1b5792958ae55086863c15a4b6397d7
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_1.png ADDED

Git LFS Details

  • SHA256: b87e8926095966c31de5b83a98abfe6da2431fc5d0f11e65262bd5553ed83c99
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/figures/standard_binary/fold_0/epoch_010_2.png ADDED

Git LFS Details

  • SHA256: 7fe5a5e0371eab0ef03e90d2af4bf84c0b2c429d11d54100b2d7858fcbe17470
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_discriminator.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d37fd879442368aaba7813f2549a1dda9c2376be2e6ec000d6352b2901e7207
3
+ size 11107072
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/best_dice_generator.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b28647f12d0639cb2f15a03e6f9334c45dfff35b25cf90df8504dd10d931598
3
+ size 124213136
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "variant": 1,
3
+ "variant_name": "Multiclass_AttentionD_AdaptiveLoss",
4
+ "preprocessing": "standard",
5
+ "class_scenario": "binary",
6
+ "fold_id": 0,
7
+ "num_classes": 2,
8
+ "batch_size": 4,
9
+ "epochs": 20,
10
+ "lambda_seg": 50,
11
+ "lambda_gan": 1,
12
+ "focal_gamma": 0.5,
13
+ "beta_threshold": 0.25,
14
+ "beta_smoothness": 0.05,
15
+ "learning_rate": 0.0002,
16
+ "beta_1": 0.9,
17
+ "attention_weight": 2.0,
18
+ "innovation": "Phase-transitioning segmentation loss (Weighted CE \u2192 Focal Loss)"
19
+ }
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/download_models.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Visit our Hugging Face link for downloading the trained models.
models/for_GM/results_fold_0_var_1_bet_zscore_gm/models/standard_binary/fold_0/history.json ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "gen_total_loss": [
3
+ 16.160808563232422,
4
+ 15.84605598449707,
5
+ 15.596329689025879,
6
+ 15.1962308883667,
7
+ 15.236580848693848
8
+ ],
9
+ "gen_gan_loss": [
10
+ 1.4460713863372803,
11
+ 1.3656171560287476,
12
+ 1.3531222343444824,
13
+ 1.1605579853057861,
14
+ 1.2858413457870483
15
+ ],
16
+ "gen_seg_loss": [
17
+ 0.2942947447299957,
18
+ 0.28960874676704407,
19
+ 0.2848641574382782,
20
+ 0.2807134687900543,
21
+ 0.27901479601860046
22
+ ],
23
+ "gen_wce_loss": [
24
+ 0.427320659160614,
25
+ 0.5918631553649902,
26
+ 0.6427512764930725,
27
+ 0.6484421491622925,
28
+ 0.6594918966293335
29
+ ],
30
+ "gen_focal_loss": [
31
+ 0.2933984398841858,
32
+ 0.28957146406173706,
33
+ 0.28486335277557373,
34
+ 0.2807134687900543,
35
+ 0.27901479601860046
36
+ ],
37
+ "disc_loss": [
38
+ 0.9733297228813171,
39
+ 0.9525837898254395,
40
+ 0.9791364669799805,
41
+ 1.0745328664779663,
42
+ 0.9908205270767212
43
+ ],
44
+ "val_loss": [
45
+ 0.2753802239894867,
46
+ 0.28040140867233276,
47
+ 0.28456518054008484,
48
+ 0.28689467906951904,
49
+ 0.2823749780654907
50
+ ],
51
+ "beta_value": [
52
+ 0.9933071732521057,
53
+ 0.9998766183853149,
54
+ 0.9999977350234985,
55
+ 1.0,
56
+ 1.0
57
+ ],
58
+ "val_metrics": [
59
+ {
60
+ "dice": {
61
+ "class_0": 0.9691289499147928,
62
+ "class_1": 0.752369057690533,
63
+ "mean": 0.8607490038026628
64
+ },
65
+ "precision": {
66
+ "class_0": 0.9797553013827063,
67
+ "class_1": 0.6921517437594977,
68
+ "mean": 0.835953522571102
69
+ },
70
+ "recall": {
71
+ "class_0": 0.9587306304286759,
72
+ "class_1": 0.8240626410757951,
73
+ "mean": 0.8913966357522355
74
+ }
75
+ },
76
+ {
77
+ "dice": {
78
+ "class_0": 0.9697302587837198,
79
+ "class_1": 0.7479798920589907,
80
+ "mean": 0.8588550754213553
81
+ },
82
+ "precision": {
83
+ "class_0": 0.976320417478653,
84
+ "class_1": 0.708180715854344,
85
+ "mean": 0.8422505666664984
86
+ },
87
+ "recall": {
88
+ "class_0": 0.9632284706744223,
89
+ "class_1": 0.7925187994877733,
90
+ "mean": 0.8778736350810978
91
+ }
92
+ },
93
+ {
94
+ "dice": {
95
+ "class_0": 0.9700212734386713,
96
+ "class_1": 0.7450693229965066,
97
+ "mean": 0.857545298217589
98
+ },
99
+ "precision": {
100
+ "class_0": 0.9743977498147104,
101
+ "class_1": 0.7176589260174784,
102
+ "mean": 0.8460283379160944
103
+ },
104
+ "recall": {
105
+ "class_0": 0.9656839348841134,
106
+ "class_1": 0.7746567035718136,
107
+ "mean": 0.8701703192279635
108
+ }
109
+ },
110
+ {
111
+ "dice": {
112
+ "class_0": 0.9707925785575585,
113
+ "class_1": 0.7425312591148605,
114
+ "mean": 0.8566619188362095
115
+ },
116
+ "precision": {
117
+ "class_0": 0.9715280350271916,
118
+ "class_1": 0.7376090600933536,
119
+ "mean": 0.8545685475602727
120
+ },
121
+ "recall": {
122
+ "class_0": 0.9700582347414827,
123
+ "class_1": 0.7475195929191307,
124
+ "mean": 0.8587889138303066
125
+ }
126
+ },
127
+ {
128
+ "dice": {
129
+ "class_0": 0.9707689737822686,
130
+ "class_1": 0.7466375133242413,
131
+ "mean": 0.8587032435532549
132
+ },
133
+ "precision": {
134
+ "class_0": 0.9731953621432797,
135
+ "class_1": 0.730843840931104,
136
+ "mean": 0.8520196015371919
137
+ },
138
+ "recall": {
139
+ "class_0": 0.9683546543618544,
140
+ "class_1": 0.7631288712746304,
141
+ "mean": 0.8657417628182424
142
+ }
143
+ }
144
+ ]
145
+ }