Commit
·
82949fb
1
Parent(s):
7bfdced
Skip the background segmentation mask
Browse filesCompute the HD95 metric
Allow the user to save or not the generated NIfTI images
- COMET/Evaluate_network.py +20 -16
COMET/Evaluate_network.py
CHANGED
|
@@ -63,6 +63,7 @@ if __name__ == '__main__':
|
|
| 63 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
| 64 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
| 65 |
parser.add_argument('--fullres', action='store_true', default=False)
|
|
|
|
| 66 |
args = parser.parse_args()
|
| 67 |
if args.model is not None:
|
| 68 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
|
@@ -95,10 +96,10 @@ if __name__ == '__main__':
|
|
| 95 |
|
| 96 |
with h5py.File(list_test_files[0], 'r') as f:
|
| 97 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
| 98 |
-
nb_labels = f['fix_segmentations'][:].shape[-1]
|
| 99 |
|
| 100 |
# Header of the metrics csv file
|
| 101 |
-
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 102 |
|
| 103 |
# TF stuff
|
| 104 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
|
@@ -163,17 +164,17 @@ if __name__ == '__main__':
|
|
| 163 |
print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
|
| 164 |
|
| 165 |
try:
|
| 166 |
-
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
| 167 |
'VxmDense': vxm.networks.VxmDense,
|
| 168 |
'AdamAccumulated': AdamAccumulated,
|
| 169 |
'loss': loss_fncs,
|
| 170 |
'metric': metric_fncs},
|
| 171 |
compile=False)
|
| 172 |
except ValueError as e:
|
| 173 |
-
enc_features = [
|
| 174 |
-
dec_features = [
|
| 175 |
nb_features = [enc_features, dec_features]
|
| 176 |
-
if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
| 177 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
| 178 |
nb_labels=nb_labels,
|
| 179 |
nb_unet_features=nb_features,
|
|
@@ -181,6 +182,7 @@ if __name__ == '__main__':
|
|
| 181 |
int_downsize=1,
|
| 182 |
seg_downsize=1)
|
| 183 |
else:
|
|
|
|
| 184 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
| 185 |
nb_unet_features=nb_features,
|
| 186 |
int_steps=0)
|
|
@@ -205,9 +207,9 @@ if __name__ == '__main__':
|
|
| 205 |
with h5py.File(in_batch, 'r') as f:
|
| 206 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
| 207 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
| 208 |
-
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
| 209 |
-
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
| 210 |
-
fix_centroids = f['fix_centroids'][
|
| 211 |
isotropic_shape = f['isotropic_shape'][:]
|
| 212 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
| 213 |
|
|
@@ -238,6 +240,7 @@ if __name__ == '__main__':
|
|
| 238 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
| 239 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 240 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
|
|
|
| 241 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 242 |
|
| 243 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
|
@@ -261,7 +264,7 @@ if __name__ == '__main__':
|
|
| 261 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
| 262 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 263 |
|
| 264 |
-
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
|
| 265 |
with open(metrics_file, 'a') as f:
|
| 266 |
f.write(';'.join(map(str, new_line))+'\n')
|
| 267 |
|
|
@@ -337,12 +340,13 @@ if __name__ == '__main__':
|
|
| 337 |
with open(metrics_file_fr, 'a') as f:
|
| 338 |
f.write(';'.join(map(str, new_line)) + '\n')
|
| 339 |
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
|
|
|
| 346 |
|
| 347 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
| 348 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
|
|
|
| 63 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
| 64 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
| 65 |
parser.add_argument('--fullres', action='store_true', default=False)
|
| 66 |
+
parser.add_argument('--savenifti', type=bool, default=True)
|
| 67 |
args = parser.parse_args()
|
| 68 |
if args.model is not None:
|
| 69 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
|
|
|
| 96 |
|
| 97 |
with h5py.File(list_test_files[0], 'r') as f:
|
| 98 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
| 99 |
+
nb_labels = f['fix_segmentations'][:].shape[-1] - 1 # Skip background label
|
| 100 |
|
| 101 |
# Header of the metrics csv file
|
| 102 |
+
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 103 |
|
| 104 |
# TF stuff
|
| 105 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
|
|
|
| 164 |
print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
|
| 165 |
|
| 166 |
try:
|
| 167 |
+
network = tf.keras.models.load_model(MODEL_FILE, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
| 168 |
'VxmDense': vxm.networks.VxmDense,
|
| 169 |
'AdamAccumulated': AdamAccumulated,
|
| 170 |
'loss': loss_fncs,
|
| 171 |
'metric': metric_fncs},
|
| 172 |
compile=False)
|
| 173 |
except ValueError as e:
|
| 174 |
+
enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
| 175 |
+
dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
| 176 |
nb_features = [enc_features, dec_features]
|
| 177 |
+
if False: #re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
| 178 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
| 179 |
nb_labels=nb_labels,
|
| 180 |
nb_unet_features=nb_features,
|
|
|
|
| 182 |
int_downsize=1,
|
| 183 |
seg_downsize=1)
|
| 184 |
else:
|
| 185 |
+
# only load the weights into the same model. To get the same runtime
|
| 186 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
| 187 |
nb_unet_features=nb_features,
|
| 188 |
int_steps=0)
|
|
|
|
| 207 |
with h5py.File(in_batch, 'r') as f:
|
| 208 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
| 209 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
| 210 |
+
fix_seg = f['fix_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
| 211 |
+
mov_seg = f['mov_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
| 212 |
+
fix_centroids = f['fix_centroids'][1:, ...]
|
| 213 |
isotropic_shape = f['isotropic_shape'][:]
|
| 214 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
| 215 |
|
|
|
|
| 240 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
| 241 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 242 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
| 243 |
+
hd95 = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd95, {'voxelspacing': voxel_size}))
|
| 244 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 245 |
|
| 246 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
|
|
|
| 264 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
| 265 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 266 |
|
| 267 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95, t1-t0, tre, len(missing_lbls), missing_lbls]
|
| 268 |
with open(metrics_file, 'a') as f:
|
| 269 |
f.write(';'.join(map(str, new_line))+'\n')
|
| 270 |
|
|
|
|
| 340 |
with open(metrics_file_fr, 'a') as f:
|
| 341 |
f.write(';'.join(map(str, new_line)) + '\n')
|
| 342 |
|
| 343 |
+
if args.savenifti:
|
| 344 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 345 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 346 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 347 |
+
save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 348 |
+
save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 349 |
+
save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 350 |
|
| 351 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
| 352 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|