Upload 69 files
Browse files — this view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- CVPR25_TextSegFMData_with_class.json +0 -0
- config_CT.json +93 -0
- config_nonCT.json +13 -0
- data/__init__.py +0 -0
- data/default_resampling.py +208 -0
- data/resample_torch.py +162 -0
- data/resampling_test.py +593 -0
- environment.yml +211 -0
- evaluate/SurfaceDice.py +492 -0
- evaluate/__init__.py +0 -0
- evaluate/evaluator.py +379 -0
- evaluate/merge_after_evaluate.py +198 -0
- evaluate/metric.py +46 -0
- evaluate/params.py +153 -0
- inference_medals_nifti.py +1885 -0
- model/SwinUNETR.py +1116 -0
- model/__init__.py +0 -0
- model/base_bert.py +26 -0
- model/build_model.py +103 -0
- model/dynamic-network-architectures-main/.gitignore +113 -0
- model/dynamic-network-architectures-main/LICENCE +201 -0
- model/dynamic-network-architectures-main/README.md +25 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/PKG-INFO +16 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/SOURCES.txt +24 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/dependency_links.txt +1 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/not-zip-safe +1 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/requires.txt +2 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/top_level.txt +1 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/__init__.py +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/__pycache__/__init__.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__init__.py +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/__init__.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/unet.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/resnet.py +236 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/unet.py +220 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/vgg.py +85 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__init__.py +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/__init__.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/helper.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/plain_conv_encoder.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/regularization.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual_encoders.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/simple_conv_blocks.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/unet_decoder.cpython-310.pyc +0 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/helper.py +242 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/plain_conv_encoder.py +105 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/regularization.py +86 -0
- model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/residual.py +371 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
model/dynamic-network-architectures-main/imgs/Logos/HI_Logo.png filter=lfs diff=lfs merge=lfs -text
|
CVPR25_TextSegFMData_with_class.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
config_CT.json
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"texts_soft_tissue": [
|
| 3 |
+
"Aorta in whole body CT",
|
| 4 |
+
"gallbladder in whole body CT",
|
| 5 |
+
"left kidney in whole body CT",
|
| 6 |
+
"right kidney in whole body CT",
|
| 7 |
+
"liver in whole body CT",
|
| 8 |
+
"Pancreas in whole body CT",
|
| 9 |
+
"Spleen in whole body CT",
|
| 10 |
+
"stomach in whole body CT",
|
| 11 |
+
"Left adrenal gland in whole body CT",
|
| 12 |
+
"right adrenal gland in whole body CT",
|
| 13 |
+
"Bladder in whole body CT",
|
| 14 |
+
"Esophagus in whole body CT",
|
| 15 |
+
"Heart in whole body CT",
|
| 16 |
+
"Pulmonary vein in whole body CT",
|
| 17 |
+
"Brachiocephalic trunk in whole body CT",
|
| 18 |
+
"Right subclavian artery in whole body CT",
|
| 19 |
+
"Left subclavian artery in whole body CT",
|
| 20 |
+
"Right common carotid artery in whole body CT",
|
| 21 |
+
"Left common carotid artery in whole body CT",
|
| 22 |
+
"Left brachiocephalic vein in whole body CT",
|
| 23 |
+
"Right brachiocephalic vein in whole body CT",
|
| 24 |
+
"Left atrial appendage in whole body CT",
|
| 25 |
+
"Superior vena cava in whole body CT",
|
| 26 |
+
"Inferior vena cava in whole body CT",
|
| 27 |
+
"Portal vein and splenic vein in whole body CT",
|
| 28 |
+
"Left iliac artery in whole body CT",
|
| 29 |
+
"Right iliac artery in whole body CT",
|
| 30 |
+
"Left iliac vena in whole body CT",
|
| 31 |
+
"Right iliac vena in whole body CT",
|
| 32 |
+
"Spinal cord in whole body CT",
|
| 33 |
+
"Left gluteus Maximus in whole body CT",
|
| 34 |
+
"Right gluteus Maximus in whole body CT",
|
| 35 |
+
"Left gluteus Medius in whole body CT",
|
| 36 |
+
"Right gluteus Medius in whole body CT",
|
| 37 |
+
"Left gluteus Minimus in whole body CT",
|
| 38 |
+
"Right gluteus Minimus in whole body CT",
|
| 39 |
+
"Left autochthon in whole body CT",
|
| 40 |
+
"Right autochthon in whole body CT",
|
| 41 |
+
"Left iliopsoas in whole body CT",
|
| 42 |
+
"Right iliopsoas in whole body CT"
|
| 43 |
+
],
|
| 44 |
+
"texts_bone": [
|
| 45 |
+
"Vertebrae C7 in whole body CT",
|
| 46 |
+
"Vertebrae C6 in whole body CT",
|
| 47 |
+
"Vertebrae C5 in whole body CT",
|
| 48 |
+
"Vertebrae C4 in whole body CT",
|
| 49 |
+
"Vertebrae C3 in whole body CT",
|
| 50 |
+
"Vertebrae C2 in whole body CT",
|
| 51 |
+
"Vertebrae C1 in whole body CT",
|
| 52 |
+
"Vertebrae T12 in whole body CT",
|
| 53 |
+
"Vertebrae T11 in whole body CT",
|
| 54 |
+
"Vertebrae T10 in whole body CT",
|
| 55 |
+
"Vertebrae T9 in whole body CT",
|
| 56 |
+
"Vertebrae T8 in whole body CT",
|
| 57 |
+
"Vertebrae T7 in whole body CT",
|
| 58 |
+
"Vertebrae T6 in whole body CT",
|
| 59 |
+
"Vertebrae T5 in whole body CT",
|
| 60 |
+
"Vertebrae T4 in whole body CT",
|
| 61 |
+
"Vertebrae T3 in whole body CT",
|
| 62 |
+
"Vertebrae T2 in whole body CT",
|
| 63 |
+
"Vertebrae T1 in whole body CT",
|
| 64 |
+
"Left humerus in whole body CT",
|
| 65 |
+
"Right humerus in whole body CT",
|
| 66 |
+
"Left clavicula in whole body CT",
|
| 67 |
+
"Right clavicula in whole body CT",
|
| 68 |
+
"Left femur in whole body CT",
|
| 69 |
+
"Right femur in whole body CT",
|
| 70 |
+
"Left hip in whole body CT",
|
| 71 |
+
"Right hip in whole body CT"
|
| 72 |
+
],
|
| 73 |
+
"texts_lung": [
|
| 74 |
+
"Left lung in whole body CT",
|
| 75 |
+
"Right lung in whole body CT"
|
| 76 |
+
],
|
| 77 |
+
"window_settings": {
|
| 78 |
+
"soft_tissue": {
|
| 79 |
+
"window_level": 40,
|
| 80 |
+
"window_width": 400
|
| 81 |
+
},
|
| 82 |
+
"bone": {
|
| 83 |
+
"window_level": 500,
|
| 84 |
+
"window_width": 1500
|
| 85 |
+
},
|
| 86 |
+
"lung": {
|
| 87 |
+
"window_level": -600,
|
| 88 |
+
"window_width": 1500
|
| 89 |
+
}
|
| 90 |
+
},
|
| 91 |
+
"modality": "CT",
|
| 92 |
+
"instance_label": 0
|
| 93 |
+
}
|
config_nonCT.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"texts": [
|
| 3 |
+
"Spleen in MRI"
|
| 4 |
+
],
|
| 5 |
+
"normalization_settings": {
|
| 6 |
+
"percentile_lower": 0.5,
|
| 7 |
+
"percentile_upper": 99.5,
|
| 8 |
+
"preserve_zero": true
|
| 9 |
+
},
|
| 10 |
+
"modality": "MRI",
|
| 11 |
+
"instance_label": 0
|
| 12 |
+
}
|
| 13 |
+
|
data/__init__.py
ADDED
|
File without changes
|
data/default_resampling.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from typing import Union, Tuple, List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import sklearn
|
| 8 |
+
import torch
|
| 9 |
+
from batchgenerators.augmentations.utils import resize_segmentation
|
| 10 |
+
from scipy.ndimage import map_coordinates
|
| 11 |
+
from skimage.transform import resize
|
| 12 |
+
|
| 13 |
+
# Ratio of largest to smallest spacing above which a case is treated as
# anisotropic (3 means the low-resolution axis spacing must be 3x the next
# largest spacing).
ANISO_THRESHOLD = 3


def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD):
    """Return True when the spacing ratio classifies this case as anisotropic."""
    return (np.max(spacing) / np.min(spacing)) > anisotropy_threshold


def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]):
    """Return the indices of the axis/axes whose spacing equals the maximum.

    These are the low-resolution axes; more than one index is returned when
    several axes tie for the largest spacing.
    """
    return np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]


def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
    """Scale every axis of old_shape by the ratio old_spacing / new_spacing."""
    assert len(old_spacing) == len(old_shape) == len(new_spacing)
    scaled_dims = (int(round(sp_old / sp_new * dim))
                   for sp_old, sp_new, dim in zip(old_spacing, new_spacing, old_shape))
    return np.array(list(scaled_dims))


def determine_do_sep_z_and_axis(
        force_separate_z: bool,
        current_spacing,
        new_spacing,
        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
    """Decide whether to resample slice-wise along one axis, and which axis.

    force_separate_z overrides the heuristic when it is not None; otherwise
    both the source and the target spacing are checked against the anisotropy
    threshold. When two or three axes tie for the largest spacing there is no
    single out-of-plane axis, so slice-wise resampling is disabled.
    """
    if force_separate_z is not None:
        do_separate_z = force_separate_z
        axis = get_lowres_axis(current_spacing) if force_separate_z else None
    elif get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
        do_separate_z, axis = True, get_lowres_axis(current_spacing)
    elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
        do_separate_z, axis = True, get_lowres_axis(new_spacing)
    else:
        do_separate_z, axis = False, None

    if axis is None:
        return do_separate_z, None
    if len(axis) == 3 or len(axis) == 2:
        # e.g. spacings like (0.24, 1.25, 1.25): no unique out-of-plane axis,
        # so do not resample separately along any axis
        return False, None
    return do_separate_z, axis[0]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def resample_data_or_seg_to_spacing(data: np.ndarray,
                                    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                    new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                    is_seg: bool = False,
                                    order: int = 3, order_z: int = 0,
                                    force_separate_z: Union[bool, None] = False,
                                    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """Resample a 4D (c, x, y, z) array from current_spacing to new_spacing.

    The target shape is derived from the spacing ratio; whether the
    low-resolution axis is handled slice-wise is decided by
    determine_do_sep_z_and_axis. Returns the resampled array.
    """
    do_separate_z, axis = determine_do_sep_z_and_axis(
        force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold)

    if data is not None:
        assert data.ndim == 4, "data must be c x y z"

    # first axis is the channel axis; only the spatial extent is rescaled
    spatial_shape = np.array(data.shape)[1:]
    target_shape = compute_new_shape(spatial_shape, current_spacing, new_spacing)

    return resample_data_or_seg(data, target_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
                                  new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                                  current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  is_seg: bool = False,
                                  order: int = 3, order_z: int = 0,
                                  force_separate_z: Union[bool, None] = False,
                                  separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """Resample a 4D (c, x, y, z) array/tensor to an explicitly given shape.

    needed for segmentation export. Stupid, I know

    The spacings are only used to decide whether the low-resolution axis is
    resampled slice-wise (see determine_do_sep_z_and_axis); the output spatial
    extent is exactly new_shape.
    """
    if isinstance(data, torch.Tensor):
        # .detach().cpu() so CUDA tensors and tensors requiring grad do not make
        # .numpy() raise (the original called data.numpy() directly, which only
        # works for detached CPU tensors)
        data = data.detach().cpu().numpy()

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
                                                      separate_z_anisotropy_threshold)

    if data is not None:
        assert data.ndim == 4, "data must be c x y z"

    data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
    return data_reshaped
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
                         is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,
                         do_separate_z: bool = False, order_z: int = 0, dtype_out = None):
    """
    Resample a 4D (c, x, y, z) array to a new spatial shape.

    separate_z=True will resample with order 0 along z
    :param data: array of shape (c, x, y, z); channels are resampled independently
    :param new_shape: target spatial shape (3 entries, one per spatial axis)
    :param is_seg: if True, uses resize_segmentation (label-safe) instead of skimage resize
    :param axis: index of the anisotropic (low-resolution) spatial axis; required when do_separate_z
    :param order: spline interpolation order for the in-plane resize
    :param do_separate_z: resample in-plane first, then along `axis` with order_z
    :param order_z: only applies if do_separate_z is True
    :return: resampled array of shape (c, *new_shape) (dtype_out, defaulting to data.dtype),
             or `data` unchanged when the shape already matches
    """
    assert data.ndim == 4, "data must be (c, x, y, z)"
    assert len(new_shape) == data.ndim - 1

    if is_seg:
        # label-preserving resize; takes no extra kwargs
        resize_fn = resize_segmentation
        kwargs = OrderedDict()
    else:
        resize_fn = resize
        kwargs = {'mode': 'edge', 'anti_aliasing': False}
    shape = np.array(data[0].shape)
    new_shape = np.array(new_shape)
    if dtype_out is None:
        dtype_out = data.dtype
    reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out)
    if np.any(shape != new_shape):
        data = data.astype(float, copy=False)
        if do_separate_z:
            # print("separate z, order in z is", order_z, "order inplane is", order)
            assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic'
            # in-plane target shape = new_shape with the anisotropic axis removed
            if axis == 0:
                new_shape_2d = new_shape[1:]
            elif axis == 1:
                new_shape_2d = new_shape[[0, 2]]
            else:
                new_shape_2d = new_shape[:-1]

            for c in range(data.shape[0]):
                # resize each slice perpendicular to `axis` in-plane first;
                # the extent along `axis` stays at its original size for now
                tmp = deepcopy(new_shape)
                tmp[axis] = shape[axis]
                reshaped_here = np.zeros(tmp)
                for slice_id in range(shape[axis]):
                    if axis == 0:
                        reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)
                    elif axis == 1:
                        reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
                    else:
                        reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)
                if shape[axis] != new_shape[axis]:

                    # The following few lines are blatantly copied and modified from sklearn's resize()
                    rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
                    orig_rows, orig_cols, orig_dim = reshaped_here.shape

                    # align_corners=False
                    row_scale = float(orig_rows) / rows
                    col_scale = float(orig_cols) / cols
                    dim_scale = float(orig_dim) / dim

                    # half-pixel-centered coordinate map into the source volume
                    map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
                    map_rows = row_scale * (map_rows + 0.5) - 0.5
                    map_cols = col_scale * (map_cols + 0.5) - 0.5
                    map_dims = dim_scale * (map_dims + 0.5) - 0.5

                    coord_map = np.array([map_rows, map_cols, map_dims])
                    if not is_seg or order_z == 0:
                        # images, or nearest-neighbor segmentation: one interpolation pass
                        reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None]
                    else:
                        # higher-order segmentation resize: interpolate each label's
                        # indicator separately so labels are never blended
                        unique_labels = np.sort(pd.unique(reshaped_here.ravel()))  # np.unique(reshaped_data)
                        for i, cl in enumerate(unique_labels):
                            reshaped_final[c][np.round(
                                map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z,
                                                mode='nearest')) > 0.5] = cl
                else:
                    # extent along `axis` already matches; in-plane result is final
                    reshaped_final[c] = reshaped_here
        else:
            # print("no separate z, order", order)
            for c in range(data.shape[0]):
                reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
        return reshaped_final
    else:
        # print("no resampling necessary")
        return data
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
if __name__ == '__main__':
    # Smoke test: separate-z resampling of a random 4D (c, x, y, z) volume.
    input_array = np.random.random((1, 42, 231, 142))
    output_shape = (52, 256, 256)
    # `axis` must index one of the three *spatial* dimensions (0..2). The
    # original demo passed axis=3, which is out of range and made
    # resample_data_or_seg raise an IndexError (shape[axis] on a length-3
    # array). axis=0 is the anisotropic axis of this example.
    out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=0, order=1, order_z=0,
                               do_separate_z=True)
    print(out.shape, input_array.shape)
|
data/resample_torch.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
from typing import Union, Tuple, List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from data.default_resampling import determine_do_sep_z_and_axis
|
| 10 |
+
|
| 11 |
+
ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low
|
| 12 |
+
# resolution axis must be 3x as large as the next largest spacing)
|
| 13 |
+
|
| 14 |
+
def resample_torch_simple(
        data: Union[torch.Tensor, np.ndarray],
        new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
        is_seg: bool = False,
        num_threads: int = 4,
        device: torch.device = torch.device('cpu'),
        memefficient_seg_resampling: bool = False,
        mode='linear'
):
    """Resample (c, *spatial) data to new_shape with torch's F.interpolate.

    :param data: (c, x, y, z) or (c, x, y) array/tensor; numpy in -> numpy out
    :param new_shape: target spatial shape (no channel entry)
    :param is_seg: if True, resample label maps without blending label values
    :param num_threads: torch intra-op threads to use while resampling
    :param device: device to run the interpolation on
    :param memefficient_seg_resampling: slower per-label loop that avoids the
        large one-hot temporary buffer
    :param mode: 'linear' selects bi/trilinear based on rank; otherwise passed
        through to F.interpolate
    :return: resampled data (same container type as the input); the input is
        returned unchanged when the shape already matches
    """
    if mode == 'linear':
        if data.ndim == 4:
            torch_mode = 'trilinear'
        elif data.ndim == 3:
            torch_mode = 'bilinear'
        else:
            raise RuntimeError
    else:
        torch_mode = mode

    if isinstance(new_shape, np.ndarray):
        new_shape = [int(i) for i in new_shape]

    if all([i == j for i, j in zip(new_shape, data.shape[1:])]):
        # nothing to do
        return data

    n_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    # try/finally so the global thread count is restored even when
    # interpolation raises (the original leaked the setting on error)
    try:
        new_shape = tuple(new_shape)
        with torch.no_grad():
            input_was_numpy = isinstance(data, np.ndarray)
            if input_was_numpy:
                data = torch.from_numpy(data).to(device)
            else:
                orig_device = deepcopy(data.device)
                data = data.to(device)

            if is_seg:
                unique_values = torch.unique(data)
                result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
                result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
                if not memefficient_seg_resampling:
                    # believe it or not, this is 3x as fast as a plain one-hot +
                    # argmax (at least on Liver CT and on CPU). Why? Because argmax
                    # is slow. Most voxels are decided directly by the >0.7 mask;
                    # only the uncertain ones go through argmax.
                    result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
                                             device=device)
                    # scale so float16 retains enough precision for the threshold test
                    scale_factor = 1000
                    done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
                    for i, u in enumerate(unique_values):
                        result_tmp[i] = \
                            F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
                                          antialias=False)[0]
                        mask = result_tmp[i] > (0.7 * scale_factor)
                        result[mask] = u.item()
                        done_mask |= mask
                    if not torch.all(done_mask):
                        # remaining voxels: pick the label with the largest weight
                        result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
                else:
                    # per-label threshold pass; result is pre-zeroed.
                    # (The original contained a dead `if u == 0: pass` here, which
                    # had no effect; removed.)
                    for u in unique_values:
                        result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode,
                                             antialias=False)[0] > 0.5] = u
            else:
                result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]

            if input_was_numpy:
                result = result.cpu().numpy()
            else:
                result = result.to(orig_device)
    finally:
        torch.set_num_threads(n_threads)
    return result
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def resample_torch_fornnunet(
        data: Union[torch.Tensor, np.ndarray],
        new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
        current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
        new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
        is_seg: bool = False,
        num_threads: int = 4,
        device: torch.device = torch.device('cpu'),
        memefficient_seg_resampling: bool = False,
        force_separate_z: Union[bool, None] = None,
        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
        mode='linear',
        aniso_axis_mode='nearest-exact'
):
    """
    Spacing-aware resampling: anisotropic cases are resampled in-plane with
    `mode` first and along the low-resolution axis with `aniso_axis_mode`.

    data must be c, x, y, z
    """
    assert data.ndim == 4, "data must be c, x, y, z"
    new_shape = [int(i) for i in new_shape]
    orig_shape = data.shape

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
                                                      separate_z_anisotropy_threshold)
    # print('shape', data.shape, 'current_spacing', current_spacing, 'new_spacing', new_spacing, 'do_separate_z', do_separate_z, 'axis', axis)

    if do_separate_z:
        was_numpy = isinstance(data, np.ndarray)
        if was_numpy:
            data = torch.from_numpy(data)

        if isinstance(axis, list):
            assert len(axis) == 1
            axis = axis[0]

        tmp = "xyz"
        axis_letter = tmp[axis]
        others_int = [i for i in range(3) if i != axis]
        others = [tmp[i] for i in others_int]

        # fold the anisotropic axis into the channel axis so each slice is
        # resampled independently in-plane
        data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")

        # resample in-plane
        tmp_new_shape = [new_shape[i] for i in others_int]
        data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
        data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                         **{
                             axis_letter: orig_shape[axis + 1],
                             others[0]: tmp_new_shape[0],
                             others[1]: tmp_new_shape[1]
                         }
                         )
        # resample out of plane with the (usually nearest-type) aniso mode
        data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
        if was_numpy:
            data = data.numpy()
        return data
    else:
        # BUGFIX: `mode` was previously not forwarded here, so a caller-supplied
        # interpolation mode was silently ignored for isotropic cases
        return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling,
                                     mode=mode)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
    # Manual entry point: only pins torch's intra-op thread pool.
    # NOTE(review): no resampling call follows — presumably leftover
    # scaffolding from profiling runs; confirm before relying on it.
    torch.set_num_threads(16)
|
data/resampling_test.py
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Tuple, List
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import time
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from default_resampling import determine_do_sep_z_and_axis
|
| 9 |
+
import psutil
|
| 10 |
+
import nibabel as nib
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# Spacing-ratio threshold: when the largest/smallest spacing ratio exceeds this,
# the anisotropic axis is resampled separately (default for
# `separate_z_anisotropy_threshold` parameters below).
ANISO_THRESHOLD = 3
|
| 15 |
+
|
| 16 |
+
def compute_new_shape(current_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      target_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> List[int]:
    """Scale each axis of ``current_shape`` by its spacing ratio and round to int.

    NOTE: a second ``compute_new_shape`` defined later in this file shadows
    this one at import time; both implement the same formula.
    """
    shape_arr = np.asarray(current_shape, dtype=np.float64)
    ratio = np.asarray(current_spacing, dtype=np.float64) / np.asarray(target_spacing, dtype=np.float64)
    return [int(round(float(dim))) for dim in shape_arr * ratio]
|
| 24 |
+
|
| 25 |
+
def optimized_3d_resample(
    data: Union[torch.Tensor, np.ndarray],
    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    target_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    is_seg: bool = False,
    device: torch.device = torch.device('cpu'),
    num_threads: int = 8,
    chunk_size: int = 64,
    force_separate_z: Union[bool, None] = None,
    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
    preserve_range: bool = True
) -> Union[torch.Tensor, np.ndarray]:
    """
    Optimized 3D image resampling with adaptive interpolation and chunked processing.

    Args:
        data: Input 3D volume [C, D, H, W] or [D, H, W]
        current_spacing: Current voxel spacing (z, y, x)
        target_spacing: Target voxel spacing (z, y, x)
        is_seg: Whether the input is a segmentation mask
        device: Torch device for computation
        num_threads: Number of threads for CPU operations
        chunk_size: Size of chunks for large volume processing
        force_separate_z: Force separate z resampling
        separate_z_anisotropy_threshold: Threshold for anisotropic resampling
        preserve_range: Preserve original value range for non-segmentation data

    Returns:
        Resampled 3D volume (numpy array if the input was numpy, else tensor).

    NOTE(review): the ``do_separate_z`` branch looks broken — see inline notes.
    Only the direct-resampling branch appears runnable as written.
    """
    print(f"\nStarting optimized_3d_resample with input shape: {data.shape}, is_seg: {is_seg}")
    # Remember input container type so the output matches it.
    input_was_numpy = isinstance(data, np.ndarray)
    if input_was_numpy:
        data = torch.from_numpy(data).to(device)
    else:
        data = data.to(device)
    print(f"Input converted to tensor on {device}, shape: {data.shape}")

    # Normalize to channel-first 4-D (C, D, H, W).
    if data.ndim == 3:
        data = data.unsqueeze(0)
    assert data.ndim == 4, "Data must be 3D or 4D (C, D, H, W)"

    new_shape = compute_new_shape(data.shape[1:], current_spacing, target_spacing)
    print(f"Computed new shape: {new_shape} from current_spacing: {current_spacing}, target_spacing: {target_spacing}")

    # Short-circuit when no spatial change is required.
    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        print("No resampling needed, shapes identical.")
        return data.cpu().numpy() if input_was_numpy else data

    # Nearest-neighbour keeps labels intact; trilinear smooths intensities.
    mode = 'nearest' if is_seg else 'trilinear'
    aniso_axis_mode = 'nearest-exact' if is_seg else 'linear'
    print(f"Interpolation mode: {mode}, Anisotropic axis mode: {aniso_axis_mode}")

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing,
                                                      target_spacing, separate_z_anisotropy_threshold)
    print(f"Do separate Z: {do_separate_z}, Axis: {axis}")

    # Capture the intensity range before interpolation so overshoot can be clamped.
    if preserve_range and not is_seg:
        v_min, v_max = data.min(), data.max()
        print(f"Preserving range for non-segmentation data: min={v_min.item():.4f}, max={v_max.item():.4f}")

    torch.set_num_threads(num_threads)
    print(f"Set number of threads to {num_threads}")

    start_time = time.time()
    if do_separate_z:
        # Two-pass anisotropic resampling: first in-plane, then along the odd axis.
        tmp = "xyz"
        axis_letter = tmp[axis]
        others_int = [i for i in range(3) if i != axis]
        others = [tmp[i] for i in others_int]
        print(f"Separate Z resampling along axis {axis_letter}, others: {others}")

        tmp_new_shape = [new_shape[i] for i in others_int]
        print(f"First pass: Resampling to shape {tmp_new_shape} for axes {others}")
        # Fold the anisotropic axis into the channel dim -> 3-D tensor.
        data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
        print(f"Rearranged data shape: {data.shape}")
        # NOTE(review): _chunked_resample unpacks `C, D, H, W = volume.shape`
        # and a 3-element target, but `data` is 3-D and `tmp_new_shape` has 2
        # elements here — this call looks like it would raise. Compare
        # resample_torch_fornnunet, which routes this pass through
        # resample_torch_simple instead. TODO confirm and fix.
        data = _chunked_resample(data, tmp_new_shape, mode, chunk_size, device, is_seg)
        print(f"After first pass resampling, shape: {data.shape}")

        # NOTE(review): `axis_letter: data.shape[1]` passes the first resampled
        # in-plane dimension as the folded-axis length; the analogous code in
        # resample_torch_fornnunet uses the ORIGINAL axis length
        # (orig_shape[axis + 1]). Likely bug — verify before relying on this path.
        data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                         **{axis_letter: data.shape[1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
        print(f"Rearranged back to shape: {data.shape}")
        # NOTE(review): F.interpolate's 'linear' mode (aniso_axis_mode for
        # images) requires 3-D input, but this receives a 5-D batch — confirm.
        data = _chunked_resample(data, new_shape, aniso_axis_mode, chunk_size, device, is_seg)
        print(f"After second pass resampling, final shape: {data.shape}")
    else:
        print(f"Direct resampling to shape: {new_shape}")
        data = _chunked_resample(data, new_shape, mode, chunk_size, device, is_seg)
        print(f"After direct resampling, final shape: {data.shape}")
    resample_time = time.time() - start_time
    print(f"Resampling completed in {resample_time:.3f}s")

    if is_seg:
        # Compact integer dtype: int8 when all labels fit, else int16.
        unique_values = torch.unique(data)
        result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
        data = data.round().to(result_dtype)
        print(f"Segmentation data rounded and converted to {result_dtype}, unique values: {unique_values.tolist()}")

    if preserve_range and not is_seg:
        data = torch.clamp(data, v_min, v_max)
        print(f"Clamped data to original range: min={v_min.item():.4f}, max={v_max.item():.4f}")

    output = data.cpu().numpy() if input_was_numpy else data
    print(f"Output shape: {output.shape}, type: {type(output)}")
    return output
|
| 129 |
+
|
| 130 |
+
def _chunked_resample(
    volume: torch.Tensor,
    target_shape: Tuple[int, ...],
    mode: str,
    chunk_size: int,
    device: torch.device,
    is_seg: bool
) -> torch.Tensor:
    """Chunked resampling for large volumes with adaptive chunk sizing.

    Expects a 4-D (C, D, H, W) tensor and a 3-element ``target_shape``
    (the unpacking below fails otherwise). Volumes up to 128^3 voxels are
    interpolated in one shot; larger ones are processed per output chunk.

    NOTE(review): each chunk is interpolated with chunk-local coordinates
    (input window -> exact output chunk size), which does not reproduce the
    global scale factor exactly and can leave seams at chunk boundaries for
    non-nearest modes — verify acceptability before production use.
    """
    print(f"\nStarting _chunked_resample with input shape: {volume.shape}, target shape: {target_shape}")
    C, D, H, W = volume.shape
    tD, tH, tW = target_shape

    # Adaptive chunk size based on available memory
    if device.type == 'cpu':
        available_memory = psutil.virtual_memory().available / 1024**2  # in MB
    else:
        total_memory = torch.cuda.get_device_properties(device).total_memory / 1024**2  # in MB
        allocated_memory = torch.cuda.memory_allocated(device) / 1024**2
        available_memory = total_memory - allocated_memory

    # NOTE: nelement() == numel(), so this reduces to element_size() (bytes/voxel).
    mem_per_voxel = volume.element_size() * volume.nelement() / volume.numel()
    # NOTE(review): target_voxel_count is computed but never used.
    target_voxel_count = C * tD * tH * tW
    # Use a smaller share of GPU memory than of host RAM.
    chunk_mem_ratio = 0.5 if device.type == 'cpu' else 0.3
    # Cube-root converts a voxel budget into an edge length; floor at 32.
    adaptive_chunk_size = max(
        32,
        min(chunk_size, int((available_memory * chunk_mem_ratio / mem_per_voxel / C) ** (1/3)))
    )

    # Early return for small volumes
    if D * H * W <= 128**3:
        # NOTE(review): torch.cuda.amp.autocast is used even on the CPU path
        # (and is deprecated in favour of torch.amp.autocast) — confirm intent.
        with torch.cuda.amp.autocast(enabled=not is_seg):
            start_time = time.time()
            # Cast to float for interpolation if is_seg and mode is nearest
            input_tensor = volume.float() if is_seg and mode == 'nearest' else volume
            result = F.interpolate(
                input_tensor.unsqueeze(0),
                size=target_shape,
                mode=mode,
                align_corners=False if mode != 'nearest' else None
            ).squeeze(0)
            # Convert back to original dtype for segmentation
            if is_seg:
                result = result.round().to(volume.dtype)
            return result

    # Pre-allocate the full output; chunks are written into it in place.
    result = torch.zeros((C, tD, tH, tW), device=device, dtype=volume.dtype)

    # Output-space chunk edge derived from the input-space budget and the
    # smallest scale factor (keeps the input window within the memory budget).
    out_chunk_size = max(1, int(adaptive_chunk_size * min(tD/D, tH/H, tW/W)))

    for c in range(C):
        for z in range(0, tD, out_chunk_size):
            z_end = min(z + out_chunk_size, tD)
            for y in range(0, tH, out_chunk_size):
                y_end = min(y + out_chunk_size, tH)
                for x in range(0, tW, out_chunk_size):
                    x_end = min(x + out_chunk_size, tW)

                    # Map the output chunk back to an input window, padded by
                    # 1–2 voxels on each side for interpolation support.
                    in_z = max(0, int(z * D / tD) - 1)
                    in_z_end = min(D, int(z_end * D / tD) + 2)
                    in_y = max(0, int(y * H / tH) - 1)
                    in_y_end = min(H, int(y_end * H / tH) + 2)
                    in_x = max(0, int(x * W / tW) - 1)
                    in_x_end = min(W, int(x_end * W / tW) + 2)

                    chunk = volume[c:c+1, in_z:in_z_end, in_y:in_y_end, in_x:in_x_end]
                    chunk_target = (z_end - z, y_end - y, x_end - x)

                    with torch.cuda.amp.autocast(enabled=not is_seg):
                        start_time = time.time()
                        # Cast to float for interpolation if is_seg and mode is nearest
                        input_chunk = chunk.float() if is_seg and mode == 'nearest' else chunk
                        resampled_chunk = F.interpolate(
                            input_chunk.unsqueeze(0),
                            size=chunk_target,
                            mode=mode,
                            align_corners=False if mode != 'nearest' else None
                        ).squeeze(0)
                        # Convert back to original dtype for segmentation
                        if is_seg:
                            resampled_chunk = resampled_chunk.round().to(volume.dtype)

                    result[c, z:z_end, y:y_end, x:x_end] = resampled_chunk
                    # Release the chunk tensors eagerly to cap peak memory.
                    del chunk, resampled_chunk
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

    return result
|
| 220 |
+
|
| 221 |
+
def resample_torch_simple(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    mode: str = 'linear'
) -> Union[torch.Tensor, np.ndarray]:
    """
    Resample a channel-first volume to ``new_shape`` with torch interpolation.

    Args:
        data: (C, *spatial) tensor or array; 4-D input uses trilinear,
            lower-dimensional input bilinear (when ``mode == 'linear'``).
        new_shape: target spatial shape (one entry per spatial axis).
        is_seg: if True, resample each label's indicator map separately so
            labels are never blended.
        num_threads: torch CPU thread count used during the call (restored
            afterwards, even on error).
        device: device the interpolation runs on.
        memefficient_seg_resampling: if True, threshold one label at a time
            (less memory, no argmax tie-break pass).
        mode: 'linear' (mapped to tri-/bilinear) or any torch interpolate mode.

    Returns:
        The resampled volume in the same container type (numpy in -> numpy
        out; tensor in -> tensor on its original device). If the spatial
        shape already matches, the input object is returned unchanged.
    """
    # Map the generic 'linear' request onto the dimensionality-specific mode.
    if mode == 'linear':
        torch_mode = 'trilinear' if data.ndim == 4 else 'bilinear'
    else:
        torch_mode = mode

    if isinstance(new_shape, np.ndarray):
        new_shape = [int(i) for i in new_shape]

    # Nothing to do: spatial shape already matches.
    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        return data

    prev_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        new_shape = tuple(new_shape)
        with torch.no_grad():
            input_was_numpy = isinstance(data, np.ndarray)
            if input_was_numpy:
                data = torch.from_numpy(data).to(device)
            else:
                # torch.device is immutable; no copy needed (was deepcopy).
                orig_device = data.device
                data = data.to(device)

            if is_seg:
                unique_values = torch.unique(data)
                # Smallest integer dtype that can hold every label value.
                result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
                result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
                if not memefficient_seg_resampling:
                    # One float16 probability map per label; scaled by 1000 so
                    # float16 precision suffices for the 0.7 threshold below.
                    result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
                                             device=device)
                    scale_factor = 1000
                    done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
                    for i, u in enumerate(unique_values):
                        result_tmp[i] = F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
                                                      antialias=False)[0]
                        # Confident voxels (>70% weight) are assigned directly.
                        mask = result_tmp[i] > (0.7 * scale_factor)
                        result[mask] = u.item()
                        done_mask |= mask
                    # Ambiguous voxels fall back to the argmax over all labels.
                    if not torch.all(done_mask):
                        result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
                else:
                    # Memory-efficient path: threshold each non-background
                    # label at 0.5; later labels overwrite earlier ones.
                    for i, u in enumerate(unique_values):
                        if u == 0:
                            continue
                        result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[0] > 0.5] = u
            else:
                result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]

            if input_was_numpy:
                result = result.cpu().numpy()
            else:
                result = result.to(orig_device)
    finally:
        # Restore the caller's thread count even if interpolation raised.
        torch.set_num_threads(prev_threads)
    return result
|
| 284 |
+
|
| 285 |
+
def resample_torch_fornnunet(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    force_separate_z: Union[bool, None] = None,
    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
    mode: str = 'linear',
    aniso_axis_mode: str = 'nearest-exact'
) -> Union[torch.Tensor, np.ndarray]:
    """
    nnU-Net-style resampling: when the spacing is strongly anisotropic, the
    in-plane axes are resampled with ``mode`` first and the anisotropic axis
    with ``aniso_axis_mode`` afterwards; otherwise a single isotropic pass.

    Args:
        data: (C, X, Y, Z) tensor or array.
        new_shape: target spatial shape (x, y, z).
        current_spacing / new_spacing: voxel spacings used to decide whether
            to resample the anisotropic axis separately.
        is_seg / num_threads / device / memefficient_seg_resampling / mode:
            forwarded to :func:`resample_torch_simple`.
        force_separate_z: override the automatic anisotropy decision.
        separate_z_anisotropy_threshold: spacing ratio above which the two-pass
            path is taken.
        aniso_axis_mode: interpolation mode for the second (anisotropic) pass.

    Returns:
        Resampled volume in the same container type as the input.
    """
    assert data.ndim == 4, "data must be c, x, y, z"
    new_shape = [int(i) for i in new_shape]
    orig_shape = data.shape

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
                                                      separate_z_anisotropy_threshold)

    if do_separate_z:
        was_numpy = isinstance(data, np.ndarray)
        if was_numpy:
            data = torch.from_numpy(data)

        # determine_do_sep_z_and_axis may return a list of candidate axes.
        if isinstance(axis, list):
            axis = axis[0]

        tmp = "xyz"
        axis_letter = tmp[axis]
        others_int = [i for i in range(3) if i != axis]
        others = [tmp[i] for i in others_int]

        # Pass 1: fold the anisotropic axis into the channel dim and resample
        # the remaining two axes as a stack of 2-D slices.
        data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
        tmp_new_shape = [new_shape[i] for i in others_int]
        data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
        # Unfold using the ORIGINAL length of the anisotropic axis.
        data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                         **{axis_letter: orig_shape[axis + 1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
        # Pass 2: resample the anisotropic axis with its dedicated mode.
        data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
        if was_numpy:
            data = data.numpy()
        return data
    else:
        # Bug fix: the isotropic branch previously dropped the caller's
        # ``mode`` and always used resample_torch_simple's default.
        return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling,
                                     mode=mode)
|
| 332 |
+
|
| 333 |
+
def dice_score(pred: np.ndarray, true: np.ndarray) -> float:
    """Return the Dice overlap of two (binary) masks; 0 when both are empty."""
    p = np.ravel(pred)
    t = np.ravel(true)
    overlap = np.dot(p, t)
    denom = np.sum(p) + np.sum(t) + 1e-8  # epsilon guards the all-empty case
    return (2. * overlap) / denom
|
| 339 |
+
|
| 340 |
+
# Placeholder for compute_new_shape if not provided
|
| 341 |
+
def compute_new_shape(original_shape, current_spacing, target_spacing):
    """
    Compute the resampled shape from the spacing ratio.

    original_shape / current_spacing / target_spacing are all in (z, y, x)
    order. Returns a tuple of ints.

    NOTE: this definition shadows the earlier compute_new_shape (same
    formula, which returns a list) — this is the one in effect at runtime.
    """
    scaled = (dim * (cur / tgt)
              for dim, cur, tgt in zip(original_shape, current_spacing, target_spacing))
    return tuple(int(round(v)) for v in scaled)
|
| 351 |
+
|
| 352 |
+
# Function to save as NIfTI
|
| 353 |
+
def save_nii(array, spacing, output_path, is_seg=False):
    """
    Save an image or segmentation volume as a NIfTI file.

    Args:
        array: (C, Z, Y, X) or (Z, Y, X) numpy array or torch tensor.
        spacing: voxel spacing in (z, y, x) order.
        output_path: destination .nii / .nii.gz path.
        is_seg: if True, store as int32 (label map); otherwise float32.
    """
    # Convert torch tensor to numpy if necessary
    if isinstance(array, torch.Tensor):
        array = array.cpu().numpy()

    # Convert data type for NIfTI compatibility
    if is_seg:
        array = array.astype(np.int32)  # labels as integers
    else:
        array = array.astype(np.float32)  # intensities as float32

    # Reorder axes so the NIfTI voxel array is (X, Y, Z[, C]).
    if array.ndim == 4:
        # Bug fix: transpose(2, 3, 1, 0) produced (Y, X, Z, C), contradicting
        # the intended (X, Y, Z, C) layout.
        array = array.transpose(3, 2, 1, 0)  # (C, Z, Y, X) -> (X, Y, Z, C)
    else:
        # Bug fix: transpose(2, 3, 1) is invalid for a 3-D array (axis 3
        # out of range) and would raise at runtime.
        array = array.transpose(2, 1, 0)  # (Z, Y, X) -> (X, Y, Z)

    # Diagonal affine must follow the voxel-array axis order (X, Y, Z), so
    # the (z, y, x) spacing is reversed before building it.
    affine = np.diag(list(spacing)[::-1] + [1.0])
    nii_img = nib.Nifti1Image(array, affine=affine)
    nib.save(nii_img, output_path)
    print(f"Saved: {output_path}")
|
| 379 |
+
|
| 380 |
+
# Main resampling function
|
| 381 |
+
def main():
    """Compare optimized_3d_resample and resample_torch_fornnunet against a
    resample_torch_simple reference on one .npz case, for several target
    spacings, saving every result as NIfTI and printing timing/MAE/Dice.

    NOTE(review): paths are machine-specific and hard-coded below.
    """
    torch.set_num_threads(4)
    # NOTE(review): the original comment said "Force CPU" but CUDA is
    # selected here — confirm which device is intended.
    device = torch.device('cuda') #torch.device('cpu') # Force CPU as per provided code
    print(f"\nRunning tests on device: {device}")

    # Define paths
    npz_file_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/inputs/Microscopy_cremi_000_sc.npz"
    gt_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/gts/Microscopy_cremi_000_sc.npz"
    output_dir = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/workspace_teamx/outputs_test_resample"

    # Ensure output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load input data
    data = np.load(npz_file_path, allow_pickle=True)
    img_array = data['imgs']  # Shape: (C, Z, Y, X) or (Z, Y, X)
    img_spacing = data['spacing']  # (z, y, x)
    # NOTE(review): the spacing stored in the file is discarded and replaced
    # with unit spacing — confirm this override is intentional.
    img_spacing = [1.0, 1.0, 1.0]  # Override as per provided code
    gt_data = np.load(gt_path, allow_pickle=True)
    gt_array = gt_data['gts']  # Shape: (C, Z, Y, X) or (Z, Y, X)

    # Convert data types to PyTorch-compatible types
    img_array = img_array.astype(np.float32)  # Convert image to float32
    gt_array = gt_array.astype(np.int32)  # Convert segmentation mask to int32

    # Ensure img_array and gt_array have channel dimension
    if img_array.ndim == 3:
        img_array = img_array[np.newaxis, ...]  # Add channel dimension: (1, Z, Y, X)
    if gt_array.ndim == 3:
        gt_array = gt_array[np.newaxis, ...]  # Add channel dimension: (1, Z, Y, X)

    # Define target spacings to test
    target_spacings = [
        (1.2, 1.2, 1.2),
        (1.5, 1.5, 1.5),
        (2.0, 2.0, 2.0),
    ]

    # Original shape and spacing
    original_shape = img_array.shape[1:]  # (Z, Y, X)
    current_spacing = img_spacing
    print(f"\nOriginal image shape: {original_shape}, Current spacing (z,y,x): {current_spacing}")

    for target_spacing in target_spacings:
        print(f"\n=== Resampling to Target Spacing: {target_spacing} ===")

        # Compute new shape
        new_shape = compute_new_shape(original_shape, current_spacing, target_spacing)
        print(f"Computed target shape: {new_shape}")

        # === Image Resampling ===
        print("\nResampling image...")

        # Reference resampling (treated as "ground truth" for MAE below).
        print("Computing ground truth with resample_torch_simple...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()  # Ensure GPU operations are complete
        gt_img = resample_torch_simple(
            img_array,
            new_shape=new_shape,
            is_seg=False,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()  # Ensure GPU operations are complete
        gt_time = time.time() - start_time
        output_path = os.path.join(output_dir, f"img_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Ground truth image shape: {gt_img.shape}, Time: {gt_time:.3f}s")
        save_nii(gt_img, target_spacing, output_path, is_seg=False)

        # Optimized resampling
        print("Running optimized_3d_resample...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_img_opt = optimized_3d_resample(
            img_array,
            current_spacing,
            target_spacing,
            is_seg=False,
            device=device,
            num_threads=4,
            chunk_size=64
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()

        opt_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        # NOTE(review): this subtraction assumes both resamplers produced the
        # same output shape; a mismatch would raise a broadcast error here.
        opt_mae = np.mean(np.abs(resampled_img_opt - gt_img))
        output_path = os.path.join(output_dir, f"img_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Optimized image shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {opt_mae:.6f}")
        save_nii(resampled_img_opt, target_spacing, output_path, is_seg=False)

        # nnU-Net-style resampling
        print("Running resample_torch_fornnunet...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_img_orig = resample_torch_fornnunet(
            img_array,
            new_shape,
            current_spacing,
            target_spacing,
            is_seg=False,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        orig_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        orig_mae = np.mean(np.abs(resampled_img_orig - gt_img))
        output_path = os.path.join(output_dir, f"img_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Original image shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {orig_mae:.6f}")
        save_nii(resampled_img_orig, target_spacing, output_path, is_seg=False)

        # === Segmentation Mask Resampling ===
        print("\nResampling segmentation mask...")

        # Reference resampling for the mask.
        print("Computing ground truth with resample_torch_simple...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_seg = resample_torch_simple(
            gt_array,
            new_shape=new_shape,
            is_seg=True,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_seg_time = time.time() - start_time
        output_path = os.path.join(output_dir, f"seg_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Ground truth segmentation shape: {gt_seg.shape}, Time: {gt_seg_time:.3f}s")
        save_nii(gt_seg, target_spacing, output_path, is_seg=True)

        # Optimized resampling
        print("Running optimized_3d_resample for segmentation...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_seg_opt = optimized_3d_resample(
            gt_array,
            current_spacing,
            target_spacing,
            is_seg=True,
            device=device,
            num_threads=4,
            chunk_size=64
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()

        opt_seg_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        # NOTE(review): dice_score multiplies raw label values; for multi-label
        # masks (labels > 1) this is not a standard Dice — confirm masks are binary.
        opt_dice = dice_score(resampled_seg_opt, gt_seg)
        output_path = os.path.join(output_dir, f"seg_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Optimized segmentation shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {opt_dice:.6f}")
        save_nii(resampled_seg_opt, target_spacing, output_path, is_seg=True)

        # nnU-Net-style resampling
        print("Running resample_torch_fornnunet for segmentation...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_seg_orig = resample_torch_fornnunet(
            gt_array,
            new_shape,
            current_spacing,
            target_spacing,
            is_seg=True,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()

        orig_seg_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        orig_dice = dice_score(resampled_seg_orig, gt_seg)
        output_path = os.path.join(output_dir, f"seg_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Original segmentation shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {orig_dice:.6f}")
        save_nii(resampled_seg_orig, target_spacing, output_path, is_seg=True)

        # Summary
        print(f"\n=== Summary for Target Spacing: {target_spacing} ===")
        print("Image Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, MAE: {opt_mae:.6f}")
        print(f"Original - Shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, MAE: {orig_mae:.6f}")
        print(f"Time Improvement: {(orig_time - opt_time) / orig_time * 100:.2f}%")
        print(f"MAE Improvement: {(orig_mae - opt_mae) / orig_mae * 100:.2f}%")
        print("Segmentation Mask Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, Dice: {opt_dice:.6f}")
        print(f"Original - Shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, Dice: {orig_dice:.6f}")
        print(f"Time Improvement: {(orig_seg_time - opt_seg_time) / orig_seg_time * 100:.2f}%")
        print(f"Dice Improvement: {(opt_dice - orig_dice) / orig_dice * 100:.2f}%")
|
| 591 |
+
|
| 592 |
+
# Script entry point: run the resampling comparison when executed directly.
if __name__ == '__main__':
    main()
|
environment.yml
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: medals_local_test
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- nvidia
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- _libgcc_mutex=0.1=main
|
| 8 |
+
- _openmp_mutex=5.1=1_gnu
|
| 9 |
+
- aom=3.12.1=h7934f7d_0
|
| 10 |
+
- blas=1.0=mkl
|
| 11 |
+
- brotlicffi=1.2.0.0=py310h7354ed3_0
|
| 12 |
+
- bzip2=1.0.8=h5eee18b_6
|
| 13 |
+
- ca-certificates=2025.12.2=h06a4308_0
|
| 14 |
+
- cairo=1.18.4=h44eff21_0
|
| 15 |
+
- certifi=2025.11.12=py310h06a4308_0
|
| 16 |
+
- cffi=2.0.0=py310h4eded50_1
|
| 17 |
+
- charset-normalizer=3.4.4=py310h06a4308_0
|
| 18 |
+
- cuda-cudart=12.1.105=0
|
| 19 |
+
- cuda-cupti=12.1.105=0
|
| 20 |
+
- cuda-libraries=12.1.0=0
|
| 21 |
+
- cuda-nvrtc=12.1.105=0
|
| 22 |
+
- cuda-nvtx=12.1.105=0
|
| 23 |
+
- cuda-opencl=12.9.19=0
|
| 24 |
+
- cuda-runtime=12.1.0=0
|
| 25 |
+
- cuda-version=12.9=3
|
| 26 |
+
- dav1d=1.2.1=h5eee18b_0
|
| 27 |
+
- expat=2.7.3=h7354ed3_4
|
| 28 |
+
- ffmpeg=6.1.1=hecf7045_5
|
| 29 |
+
- filelock=3.20.0=py310h06a4308_0
|
| 30 |
+
- fontconfig=2.15.0=h2c49b7f_0
|
| 31 |
+
- freetype=2.13.3=h4a9f257_0
|
| 32 |
+
- fribidi=1.0.10=h7b6447c_0
|
| 33 |
+
- giflib=5.2.2=h5eee18b_0
|
| 34 |
+
- gmp=6.3.0=h6a678d5_0
|
| 35 |
+
- gmpy2=2.2.2=py310ha78e65c_0
|
| 36 |
+
- graphite2=1.3.14=h295c915_1
|
| 37 |
+
- harfbuzz=10.2.0=hdfddeaa_1
|
| 38 |
+
- icu=73.1=h6a678d5_0
|
| 39 |
+
- idna=3.11=py310h06a4308_0
|
| 40 |
+
- intel-openmp=2025.0.0=h06a4308_1171
|
| 41 |
+
- jinja2=3.1.6=py310h06a4308_0
|
| 42 |
+
- jpeg=9f=h5ce9db8_0
|
| 43 |
+
- lame=3.100=h7b6447c_0
|
| 44 |
+
- lcms2=2.17=heab6991_0
|
| 45 |
+
- ld_impl_linux-64=2.44=h153f514_2
|
| 46 |
+
- leptonica=1.82.0=hfdeec58_3
|
| 47 |
+
- lerc=4.0.0=h6a678d5_0
|
| 48 |
+
- libarchive=3.8.2=h3ec8f01_0
|
| 49 |
+
- libavif=1.3.0=h3539ee5_0
|
| 50 |
+
- libcublas=12.1.0.26=0
|
| 51 |
+
- libcufft=11.0.2.4=0
|
| 52 |
+
- libcufile=1.14.1.1=4
|
| 53 |
+
- libcurand=10.3.10.19=0
|
| 54 |
+
- libcusolver=11.4.4.55=0
|
| 55 |
+
- libcusparse=12.0.2.55=0
|
| 56 |
+
- libdeflate=1.22=h5eee18b_0
|
| 57 |
+
- libexpat=2.7.3=h7354ed3_4
|
| 58 |
+
- libffi=3.4.4=h6a678d5_1
|
| 59 |
+
- libgcc=15.2.0=h69a1729_7
|
| 60 |
+
- libgcc-ng=15.2.0=h166f726_7
|
| 61 |
+
- libglib=2.84.4=h77a78f3_0
|
| 62 |
+
- libgomp=15.2.0=h4751f2c_7
|
| 63 |
+
- libhwloc=2.12.1=default_hf1bbc79_1000
|
| 64 |
+
- libiconv=1.16=h5eee18b_3
|
| 65 |
+
- libjpeg-turbo=2.0.0=h9bf148f_0
|
| 66 |
+
- libnpp=12.0.2.50=0
|
| 67 |
+
- libnsl=2.0.0=h5eee18b_0
|
| 68 |
+
- libnvjitlink=12.1.105=0
|
| 69 |
+
- libnvjpeg=12.1.1.14=0
|
| 70 |
+
- libogg=1.3.5=h27cfd23_1
|
| 71 |
+
- libopenjpeg=2.5.4=hee96239_1
|
| 72 |
+
- libopus=1.3.1=h5eee18b_1
|
| 73 |
+
- libpng=1.6.50=h2ed474d_0
|
| 74 |
+
- libstdcxx=15.2.0=h39759b7_7
|
| 75 |
+
- libstdcxx-ng=15.2.0=hc03a8fd_7
|
| 76 |
+
- libtheora=1.2.0=h32ad74f_1
|
| 77 |
+
- libtiff=4.7.1=h029b1ac_0
|
| 78 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 79 |
+
- libvorbis=1.3.7=h7b6447c_0
|
| 80 |
+
- libvpx=1.15.2=h4cb591d_0
|
| 81 |
+
- libwebp=1.6.0=h089d785_0
|
| 82 |
+
- libwebp-base=1.6.0=hb7bb969_0
|
| 83 |
+
- libxcb=1.17.0=h9b100fa_0
|
| 84 |
+
- libxml2=2.13.9=h2c43086_0
|
| 85 |
+
- libzlib=1.3.1=hb25bd0a_0
|
| 86 |
+
- llvm-openmp=14.0.6=h9e868ea_0
|
| 87 |
+
- lz4-c=1.9.4=h6a678d5_1
|
| 88 |
+
- markupsafe=3.0.2=py310h5eee18b_0
|
| 89 |
+
- mkl=2025.0.0=hacee8c2_941
|
| 90 |
+
- mkl-service=2.5.2=py310hacdc0fc_0
|
| 91 |
+
- mkl_fft=2.1.1=py310h8fe796d_0
|
| 92 |
+
- mkl_random=1.3.0=py310h505adc9_0
|
| 93 |
+
- mpc=1.3.1=h5eee18b_0
|
| 94 |
+
- mpfr=4.2.1=h5eee18b_0
|
| 95 |
+
- mpmath=1.3.0=py310h06a4308_0
|
| 96 |
+
- ncurses=6.5=h7934f7d_0
|
| 97 |
+
- networkx=3.4.2=py310h06a4308_0
|
| 98 |
+
- ocl-icd=2.3.3=h47b2149_0
|
| 99 |
+
- opencl-headers=2025.07.22=hfb20e49_0
|
| 100 |
+
- openh264=2.6.0=he621ea3_0
|
| 101 |
+
- openjpeg=2.5.4=h4e0627c_1
|
| 102 |
+
- openssl=3.0.18=hd6dcaed_0
|
| 103 |
+
- pcre2=10.46=hf426167_0
|
| 104 |
+
- pillow=12.0.0=py310h3b88751_1
|
| 105 |
+
- pip=25.3=pyhc872135_0
|
| 106 |
+
- pixman=0.46.4=h7934f7d_0
|
| 107 |
+
- pthread-stubs=0.3=h0ce48e5_1
|
| 108 |
+
- pycparser=2.23=py310h06a4308_0
|
| 109 |
+
- pysocks=1.7.1=py310h06a4308_1
|
| 110 |
+
- python=3.10.19=h6fa692b_0
|
| 111 |
+
- pytorch-cuda=12.1=ha16c6d3_6
|
| 112 |
+
- pytorch-mutex=1.0=cuda
|
| 113 |
+
- pyyaml=6.0.3=py310h591646f_0
|
| 114 |
+
- readline=8.3=hc2a1206_0
|
| 115 |
+
- requests=2.32.5=py310h06a4308_1
|
| 116 |
+
- setuptools=80.9.0=py310h06a4308_0
|
| 117 |
+
- sqlite=3.51.0=h2a70700_0
|
| 118 |
+
- sympy=1.14.0=py310h06a4308_1
|
| 119 |
+
- tbb=2022.3.0=h698db13_0
|
| 120 |
+
- tbb-devel=2022.3.0=h698db13_0
|
| 121 |
+
- tesseract=5.2.0=hb0d2e87_3
|
| 122 |
+
- tk=8.6.15=h54e0aa7_0
|
| 123 |
+
- typing_extensions=4.15.0=py310h06a4308_0
|
| 124 |
+
- urllib3=2.6.1=py310h06a4308_0
|
| 125 |
+
- wheel=0.45.1=py310h06a4308_0
|
| 126 |
+
- xorg-libx11=1.8.12=h9b100fa_1
|
| 127 |
+
- xorg-libxau=1.0.12=h9b100fa_0
|
| 128 |
+
- xorg-libxdmcp=1.1.5=h9b100fa_0
|
| 129 |
+
- xorg-libxext=1.3.6=h9b100fa_0
|
| 130 |
+
- xorg-libxrender=0.9.12=h9b100fa_0
|
| 131 |
+
- xorg-xorgproto=2024.1=h5eee18b_1
|
| 132 |
+
- xz=5.6.4=h5eee18b_1
|
| 133 |
+
- yaml=0.2.5=h7b6447c_0
|
| 134 |
+
- zlib=1.3.1=hb25bd0a_0
|
| 135 |
+
- zstd=1.5.7=h11fc155_0
|
| 136 |
+
- pip:
|
| 137 |
+
- acvl-utils==0.2.5
|
| 138 |
+
- argparse==1.4.0
|
| 139 |
+
- batchgenerators==0.25.1
|
| 140 |
+
- blosc2==3.12.2
|
| 141 |
+
- connected-components-3d==3.26.1
|
| 142 |
+
- contourpy==1.3.2
|
| 143 |
+
- cycler==0.12.1
|
| 144 |
+
- dicom2nifti==2.6.2
|
| 145 |
+
- dynamic-network-architectures==0.2
|
| 146 |
+
- einops==0.8.1
|
| 147 |
+
- fonttools==4.61.1
|
| 148 |
+
- fsspec==2025.12.0
|
| 149 |
+
- future==1.0.0
|
| 150 |
+
- hf-xet==1.2.0
|
| 151 |
+
- huggingface-hub==0.36.0
|
| 152 |
+
- imagecodecs==2025.3.30
|
| 153 |
+
- imageio==2.37.2
|
| 154 |
+
- importlib-resources==6.5.2
|
| 155 |
+
- joblib==1.5.3
|
| 156 |
+
- kiwisolver==1.4.9
|
| 157 |
+
- lazy-loader==0.4
|
| 158 |
+
- linecache2==1.0.0
|
| 159 |
+
- matplotlib==3.10.8
|
| 160 |
+
- monai==1.4.0
|
| 161 |
+
- msgpack==1.1.2
|
| 162 |
+
- ndindex==1.10.1
|
| 163 |
+
- nibabel==5.3.2
|
| 164 |
+
- nnunetv2==2.4.1
|
| 165 |
+
- numexpr==2.14.1
|
| 166 |
+
- numpy==1.26.4
|
| 167 |
+
- nvidia-cublas-cu12==12.1.3.1
|
| 168 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
| 169 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
| 170 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
| 171 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
| 172 |
+
- nvidia-cufft-cu12==11.0.2.54
|
| 173 |
+
- nvidia-curand-cu12==10.3.2.106
|
| 174 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
| 175 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
| 176 |
+
- nvidia-nccl-cu12==2.19.3
|
| 177 |
+
- nvidia-nvjitlink-cu12==12.9.86
|
| 178 |
+
- nvidia-nvtx-cu12==12.1.105
|
| 179 |
+
- packaging==25.0
|
| 180 |
+
- pandas==2.3.3
|
| 181 |
+
- platformdirs==4.5.1
|
| 182 |
+
- positional-encodings==6.0.3
|
| 183 |
+
- py-cpuinfo==9.0.0
|
| 184 |
+
- pydicom==3.0.1
|
| 185 |
+
- pyparsing==3.2.5
|
| 186 |
+
- python-dateutil==2.9.0.post0
|
| 187 |
+
- python-gdcm==3.2.2
|
| 188 |
+
- python-graphviz==0.21
|
| 189 |
+
- pytz==2025.2
|
| 190 |
+
- regex==2025.11.3
|
| 191 |
+
- safetensors==0.7.0
|
| 192 |
+
- scikit-image==0.25.2
|
| 193 |
+
- scikit-learn==1.7.2
|
| 194 |
+
- scipy==1.15.3
|
| 195 |
+
- seaborn==0.13.2
|
| 196 |
+
- simpleitk==2.5.3
|
| 197 |
+
- six==1.17.0
|
| 198 |
+
- threadpoolctl==3.6.0
|
| 199 |
+
- tifffile==2025.5.10
|
| 200 |
+
- tokenizers==0.21.4
|
| 201 |
+
- torch==2.2.0+cu121
|
| 202 |
+
- torchaudio==2.2.0+cu121
|
| 203 |
+
- torchvision==0.17.0+cu121
|
| 204 |
+
- tqdm==4.67.1
|
| 205 |
+
- traceback2==1.4.0
|
| 206 |
+
- transformers==4.51.3
|
| 207 |
+
- triton==2.2.0
|
| 208 |
+
- tzdata==2025.3
|
| 209 |
+
- unittest2==1.1.0
|
| 210 |
+
- yacs==0.1.8
|
| 211 |
+
prefix: /yinghepool/shipengcheng/.conda/envs/medals_local_test
|
evaluate/SurfaceDice.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy.ndimage
|
| 3 |
+
|
| 4 |
+
# neighbour_code_to_normals is a lookup table.
|
| 5 |
+
# For every binary neighbour code
|
| 6 |
+
# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)
|
| 7 |
+
# it contains the surface normals of the triangles (called "surfel" for
|
| 8 |
+
# "surface element" in the following). The length of the normal
|
| 9 |
+
# vector encodes the surfel area.
|
| 10 |
+
#
|
| 11 |
+
# created by compute_surface_area_lookup_table.ipynb using the
|
| 12 |
+
# marching_cube algorithm, see e.g. https://en.wikipedia.org/wiki/Marching_cubes
|
| 13 |
+
# credit to: http://medicaldecathlon.com/files/Surface_distance_based_measures.ipynb
|
| 14 |
+
neighbour_code_to_normals = [
|
| 15 |
+
[[0,0,0]],
|
| 16 |
+
[[0.125,0.125,0.125]],
|
| 17 |
+
[[-0.125,-0.125,0.125]],
|
| 18 |
+
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
|
| 19 |
+
[[0.125,-0.125,0.125]],
|
| 20 |
+
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
|
| 21 |
+
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 22 |
+
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
|
| 23 |
+
[[-0.125,0.125,0.125]],
|
| 24 |
+
[[0.125,0.125,0.125],[-0.125,0.125,0.125]],
|
| 25 |
+
[[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
|
| 26 |
+
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 27 |
+
[[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
|
| 28 |
+
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]],
|
| 29 |
+
[[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
|
| 30 |
+
[[0.5,0.0,0.0],[0.5,0.0,0.0]],
|
| 31 |
+
[[0.125,-0.125,-0.125]],
|
| 32 |
+
[[0.0,-0.25,-0.25],[0.0,0.25,0.25]],
|
| 33 |
+
[[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 34 |
+
[[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
|
| 35 |
+
[[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 36 |
+
[[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]],
|
| 37 |
+
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 38 |
+
[[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]],
|
| 39 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,-0.125]],
|
| 40 |
+
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]],
|
| 41 |
+
[[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]],
|
| 42 |
+
[[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]],
|
| 43 |
+
[[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
|
| 44 |
+
[[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
|
| 45 |
+
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]],
|
| 46 |
+
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]],
|
| 47 |
+
[[0.125,-0.125,0.125]],
|
| 48 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
|
| 49 |
+
[[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
|
| 50 |
+
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]],
|
| 51 |
+
[[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 52 |
+
[[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
|
| 53 |
+
[[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]],
|
| 54 |
+
[[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]],
|
| 55 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,0.125]],
|
| 56 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]],
|
| 57 |
+
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 58 |
+
[[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]],
|
| 59 |
+
[[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
|
| 60 |
+
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]],
|
| 61 |
+
[[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]],
|
| 62 |
+
[[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 63 |
+
[[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
|
| 64 |
+
[[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
|
| 65 |
+
[[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]],
|
| 66 |
+
[[0.0,0.5,0.0],[0.0,-0.5,0.0]],
|
| 67 |
+
[[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]],
|
| 68 |
+
[[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
|
| 69 |
+
[[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
|
| 70 |
+
[[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
|
| 71 |
+
[[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
|
| 72 |
+
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 73 |
+
[[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
|
| 74 |
+
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
|
| 75 |
+
[[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
|
| 76 |
+
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]],
|
| 77 |
+
[[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
|
| 78 |
+
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
|
| 79 |
+
[[-0.125,-0.125,0.125]],
|
| 80 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
|
| 81 |
+
[[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 82 |
+
[[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
|
| 83 |
+
[[0.0,-0.25,0.25],[0.0,-0.25,0.25]],
|
| 84 |
+
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
|
| 85 |
+
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 86 |
+
[[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]],
|
| 87 |
+
[[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
|
| 88 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
|
| 89 |
+
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
|
| 90 |
+
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 91 |
+
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
|
| 92 |
+
[[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]],
|
| 93 |
+
[[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]],
|
| 94 |
+
[[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
|
| 95 |
+
[[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
|
| 96 |
+
[[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
|
| 97 |
+
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
|
| 98 |
+
[[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]],
|
| 99 |
+
[[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
|
| 100 |
+
[[-0.0,0.0,0.5],[0.0,0.0,0.5]],
|
| 101 |
+
[[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
|
| 102 |
+
[[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
|
| 103 |
+
[[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]],
|
| 104 |
+
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 105 |
+
[[-0.25,0.0,0.25],[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
|
| 106 |
+
[[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
|
| 107 |
+
[[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]],
|
| 108 |
+
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
|
| 109 |
+
[[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
|
| 110 |
+
[[0.25,0.0,0.25],[0.25,0.0,0.25]],
|
| 111 |
+
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 112 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 113 |
+
[[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
|
| 114 |
+
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
|
| 115 |
+
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]],
|
| 116 |
+
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 117 |
+
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
|
| 118 |
+
[[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]],
|
| 119 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 120 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
|
| 121 |
+
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 122 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 123 |
+
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 124 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 125 |
+
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]],
|
| 126 |
+
[[0.125,0.125,0.125],[0.125,-0.125,-0.125]],
|
| 127 |
+
[[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]],
|
| 128 |
+
[[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
|
| 129 |
+
[[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
|
| 130 |
+
[[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
|
| 131 |
+
[[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]],
|
| 132 |
+
[[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]],
|
| 133 |
+
[[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]],
|
| 134 |
+
[[0.0,0.25,0.25],[0.0,0.25,0.25]],
|
| 135 |
+
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]],
|
| 136 |
+
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
|
| 137 |
+
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]],
|
| 138 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
|
| 139 |
+
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]],
|
| 140 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
|
| 141 |
+
[[0.125,0.125,0.125],[0.125,0.125,0.125]],
|
| 142 |
+
[[0.125,0.125,0.125]],
|
| 143 |
+
[[0.125,0.125,0.125]],
|
| 144 |
+
[[0.125,0.125,0.125],[0.125,0.125,0.125]],
|
| 145 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
|
| 146 |
+
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]],
|
| 147 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
|
| 148 |
+
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]],
|
| 149 |
+
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
|
| 150 |
+
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]],
|
| 151 |
+
[[0.0,0.25,0.25],[0.0,0.25,0.25]],
|
| 152 |
+
[[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]],
|
| 153 |
+
[[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]],
|
| 154 |
+
[[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]],
|
| 155 |
+
[[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
|
| 156 |
+
[[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
|
| 157 |
+
[[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
|
| 158 |
+
[[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]],
|
| 159 |
+
[[0.125,0.125,0.125],[0.125,-0.125,-0.125]],
|
| 160 |
+
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]],
|
| 161 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 162 |
+
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 163 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 164 |
+
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 165 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
|
| 166 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 167 |
+
[[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]],
|
| 168 |
+
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.0,0.25,0.25],[0.0,0.25,0.25]],
|
| 169 |
+
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 170 |
+
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]],
|
| 171 |
+
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
|
| 172 |
+
[[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
|
| 173 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 174 |
+
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 175 |
+
[[0.25,0.0,0.25],[0.25,0.0,0.25]],
|
| 176 |
+
[[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
|
| 177 |
+
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
|
| 178 |
+
[[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]],
|
| 179 |
+
[[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
|
| 180 |
+
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.25,0.0,0.25],[0.25,0.0,0.25]],
|
| 181 |
+
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 182 |
+
[[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]],
|
| 183 |
+
[[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
|
| 184 |
+
[[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
|
| 185 |
+
[[-0.0,0.0,0.5],[0.0,0.0,0.5]],
|
| 186 |
+
[[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
|
| 187 |
+
[[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]],
|
| 188 |
+
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
|
| 189 |
+
[[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
|
| 190 |
+
[[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
|
| 191 |
+
[[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
|
| 192 |
+
[[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]],
|
| 193 |
+
[[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]],
|
| 194 |
+
[[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
|
| 195 |
+
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 196 |
+
[[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
|
| 197 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
|
| 198 |
+
[[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
|
| 199 |
+
[[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]],
|
| 200 |
+
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 201 |
+
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
|
| 202 |
+
[[0.0,-0.25,0.25],[0.0,-0.25,0.25]],
|
| 203 |
+
[[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
|
| 204 |
+
[[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 205 |
+
[[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
|
| 206 |
+
[[-0.125,-0.125,0.125]],
|
| 207 |
+
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
|
| 208 |
+
[[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
|
| 209 |
+
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]],
|
| 210 |
+
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
|
| 211 |
+
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
|
| 212 |
+
[[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
|
| 213 |
+
[[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 214 |
+
[[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
|
| 215 |
+
[[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
|
| 216 |
+
[[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
|
| 217 |
+
[[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
|
| 218 |
+
[[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]],
|
| 219 |
+
[[0.0,0.5,0.0],[0.0,-0.5,0.0]],
|
| 220 |
+
[[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]],
|
| 221 |
+
[[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
|
| 222 |
+
[[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
|
| 223 |
+
[[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 224 |
+
[[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]],
|
| 225 |
+
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]],
|
| 226 |
+
[[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
|
| 227 |
+
[[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]],
|
| 228 |
+
[[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 229 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]],
|
| 230 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,0.125]],
|
| 231 |
+
[[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]],
|
| 232 |
+
[[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]],
|
| 233 |
+
[[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
|
| 234 |
+
[[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
|
| 235 |
+
[[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]],
|
| 236 |
+
[[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
|
| 237 |
+
[[0.125,0.125,0.125],[0.125,-0.125,0.125]],
|
| 238 |
+
[[0.125,-0.125,0.125]],
|
| 239 |
+
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]],
|
| 240 |
+
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]],
|
| 241 |
+
[[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
|
| 242 |
+
[[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
|
| 243 |
+
[[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]],
|
| 244 |
+
[[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]],
|
| 245 |
+
[[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]],
|
| 246 |
+
[[-0.125,0.125,0.125],[0.125,-0.125,-0.125]],
|
| 247 |
+
[[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]],
|
| 248 |
+
[[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 249 |
+
[[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]],
|
| 250 |
+
[[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 251 |
+
[[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
|
| 252 |
+
[[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
|
| 253 |
+
[[0.0,-0.25,-0.25],[0.0,0.25,0.25]],
|
| 254 |
+
[[0.125,-0.125,-0.125]],
|
| 255 |
+
[[0.5,0.0,0.0],[0.5,0.0,0.0]],
|
| 256 |
+
[[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
|
| 257 |
+
[[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]],
|
| 258 |
+
[[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
|
| 259 |
+
[[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
|
| 260 |
+
[[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
|
| 261 |
+
[[0.125,0.125,0.125],[-0.125,0.125,0.125]],
|
| 262 |
+
[[-0.125,0.125,0.125]],
|
| 263 |
+
[[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
|
| 264 |
+
[[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
|
| 265 |
+
[[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
|
| 266 |
+
[[0.125,-0.125,0.125]],
|
| 267 |
+
[[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
|
| 268 |
+
[[-0.125,-0.125,0.125]],
|
| 269 |
+
[[0.125,0.125,0.125]],
|
| 270 |
+
[[0,0,0]]]
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def compute_surface_distances(mask_gt, mask_pred, spacing_mm):
    """Compute closest distances from all surface points to the other surface.

    Finds all surface elements "surfels" in the ground truth mask `mask_gt` and
    the predicted mask `mask_pred`, computes their area in mm^2 and the distance
    to the closest point on the other surface. It returns two sorted lists of
    distances together with the corresponding surfel areas. If one of the masks
    is empty, the corresponding lists are empty and all distances in the other
    list are `inf`.

    Args:
      mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
      mask_pred: 3-dim Numpy array of type bool. The predicted mask.
      spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2
        direction.

    Returns:
      A dict with
      "distances_gt_to_pred": 1-dim numpy array of type float. The distances in
          mm from all ground truth surface elements to the predicted surface,
          sorted from smallest to largest.
      "distances_pred_to_gt": 1-dim numpy array of type float. The distances in
          mm from all predicted surface elements to the ground truth surface,
          sorted from smallest to largest.
      "surfel_areas_gt": 1-dim numpy array of type float. The area in mm^2 of
          the ground truth surface elements in the same order as
          distances_gt_to_pred.
      "surfel_areas_pred": 1-dim numpy array of type float. The area in mm^2 of
          the predicted surface elements in the same order as
          distances_pred_to_gt.
    """
    # Compute the area for all 256 possible surface elements (given a 2x2x2
    # neighbourhood) according to the spacing_mm. The normals table
    # `neighbour_code_to_normals` is defined at module level.
    neighbour_code_to_surface_area = np.zeros([256])
    for code in range(256):
        normals = np.array(neighbour_code_to_normals[code])
        sum_area = 0
        for normal_idx in range(normals.shape[0]):
            # Scale each triangle normal by the voxel face areas so its norm is
            # the triangle area in mm^2.
            n = np.zeros([3])
            n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2]
            n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2]
            n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1]
            area = np.linalg.norm(n)
            sum_area += area
        neighbour_code_to_surface_area[code] = sum_area

    # Compute the bounding box of the masks to trim the volume to the smallest
    # possible processing subvolume.
    mask_all = mask_gt | mask_pred
    bbox_min = np.zeros(3, np.int64)
    bbox_max = np.zeros(3, np.int64)

    # Max projection to the x0-axis. If it is empty, both masks are empty and
    # we can return right away.
    proj_0 = np.max(np.max(mask_all, axis=2), axis=1)
    idx_nonzero_0 = np.nonzero(proj_0)[0]
    if len(idx_nonzero_0) == 0:
        return {"distances_gt_to_pred": np.array([]),
                "distances_pred_to_gt": np.array([]),
                "surfel_areas_gt": np.array([]),
                "surfel_areas_pred": np.array([])}

    bbox_min[0] = np.min(idx_nonzero_0)
    bbox_max[0] = np.max(idx_nonzero_0)

    # max projection to the x1-axis
    proj_1 = np.max(np.max(mask_all, axis=2), axis=0)
    idx_nonzero_1 = np.nonzero(proj_1)[0]
    bbox_min[1] = np.min(idx_nonzero_1)
    bbox_max[1] = np.max(idx_nonzero_1)

    # max projection to the x2-axis
    proj_2 = np.max(np.max(mask_all, axis=1), axis=0)
    idx_nonzero_2 = np.nonzero(proj_2)[0]
    bbox_min[2] = np.min(idx_nonzero_2)
    bbox_max[2] = np.max(idx_nonzero_2)

    # Crop the processing subvolume. We need to zeropad the cropped region with
    # 1 voxel at the lower, the right and the back side. This is required to
    # obtain the "full" convolution result with the 2x2x2 kernel.
    cropmask_gt = np.zeros((bbox_max - bbox_min) + 2, np.uint8)
    cropmask_pred = np.zeros((bbox_max - bbox_min) + 2, np.uint8)

    cropmask_gt[0:-1, 0:-1, 0:-1] = mask_gt[bbox_min[0]:bbox_max[0] + 1,
                                            bbox_min[1]:bbox_max[1] + 1,
                                            bbox_min[2]:bbox_max[2] + 1]

    cropmask_pred[0:-1, 0:-1, 0:-1] = mask_pred[bbox_min[0]:bbox_max[0] + 1,
                                                bbox_min[1]:bbox_max[1] + 1,
                                                bbox_min[2]:bbox_max[2] + 1]

    # Compute the neighbour code (local binary pattern) for each voxel. The
    # resulting arrays are spatially shifted by minus half a voxel in each
    # axis, i.e. the points are located at the corners of the original voxels.
    kernel = np.array([[[128, 64],
                        [32, 16]],
                       [[8, 4],
                        [2, 1]]])
    # NOTE: scipy.ndimage.filters.* was deprecated and removed in SciPy >= 1.15;
    # use the top-level scipy.ndimage namespace instead.
    neighbour_code_map_gt = scipy.ndimage.correlate(cropmask_gt.astype(np.uint8), kernel, mode="constant", cval=0)
    neighbour_code_map_pred = scipy.ndimage.correlate(cropmask_pred.astype(np.uint8), kernel, mode="constant", cval=0)

    # Create masks with the surface voxels: codes 0 (all outside) and 255
    # (all inside) are not on the surface.
    borders_gt = ((neighbour_code_map_gt != 0) & (neighbour_code_map_gt != 255))
    borders_pred = ((neighbour_code_map_pred != 0) & (neighbour_code_map_pred != 255))

    # Compute the distance transform (closest distance of each voxel to the
    # surface voxels). An empty surface yields all-inf distances.
    if borders_gt.any():
        distmap_gt = scipy.ndimage.distance_transform_edt(~borders_gt, sampling=spacing_mm)
    else:
        distmap_gt = np.inf * np.ones(borders_gt.shape)

    if borders_pred.any():
        distmap_pred = scipy.ndimage.distance_transform_edt(~borders_pred, sampling=spacing_mm)
    else:
        distmap_pred = np.inf * np.ones(borders_pred.shape)

    # compute the area of each surface element
    surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]
    surface_area_map_pred = neighbour_code_to_surface_area[neighbour_code_map_pred]

    # create a list of all surface elements with distance and area
    distances_gt_to_pred = distmap_pred[borders_gt]
    distances_pred_to_gt = distmap_gt[borders_pred]
    surfel_areas_gt = surface_area_map_gt[borders_gt]
    surfel_areas_pred = surface_area_map_pred[borders_pred]

    # sort them by distance (areas are kept aligned with their distances)
    if distances_gt_to_pred.shape != (0,):
        sorted_surfels_gt = np.array(sorted(zip(distances_gt_to_pred, surfel_areas_gt)))
        distances_gt_to_pred = sorted_surfels_gt[:, 0]
        surfel_areas_gt = sorted_surfels_gt[:, 1]

    if distances_pred_to_gt.shape != (0,):
        sorted_surfels_pred = np.array(sorted(zip(distances_pred_to_gt, surfel_areas_pred)))
        distances_pred_to_gt = sorted_surfels_pred[:, 0]
        surfel_areas_pred = sorted_surfels_pred[:, 1]

    return {"distances_gt_to_pred": distances_gt_to_pred,
            "distances_pred_to_gt": distances_pred_to_gt,
            "surfel_areas_gt": surfel_areas_gt,
            "surfel_areas_pred": surfel_areas_pred}
| 422 |
+
|
| 423 |
+
def compute_average_surface_distance(surface_distances):
    """Return the area-weighted average surface distance in both directions.

    Args:
      surface_distances: dict as produced by `compute_surface_distances`.

    Returns:
      Tuple (avg distance gt->pred, avg distance pred->gt), each the mean of
      the surfel distances weighted by the corresponding surfel areas.
    """
    def _weighted_mean(distances, areas):
        # Area-weighted average: sum(d_i * a_i) / sum(a_i).
        return np.sum(distances * areas) / np.sum(areas)

    avg_gt_to_pred = _weighted_mean(surface_distances["distances_gt_to_pred"],
                                    surface_distances["surfel_areas_gt"])
    avg_pred_to_gt = _weighted_mean(surface_distances["distances_pred_to_gt"],
                                    surface_distances["surfel_areas_pred"])
    return (avg_gt_to_pred, avg_pred_to_gt)
|
| 431 |
+
|
| 432 |
+
def compute_robust_hausdorff(surface_distances, percent):
    """Compute the robust (percentile) Hausdorff distance.

    For each direction, finds the distance below which `percent` % of the
    total surfel area lies, then returns the maximum over both directions.
    An empty surface contributes `inf`.

    Args:
      surface_distances: dict as produced by `compute_surface_distances`
        (distances are sorted ascending, areas aligned with them).
      percent: percentile in [0, 100], e.g. 95 for HD95.

    Returns:
      The robust Hausdorff distance in mm as a float.
    """
    distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
    distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
    surfel_areas_gt = surface_distances["surfel_areas_gt"]
    surfel_areas_pred = surface_distances["surfel_areas_pred"]
    if len(distances_gt_to_pred) > 0:
        # Cumulative area fraction; the percentile is taken over surfel area,
        # not over surfel count.
        surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt)
        idx = np.searchsorted(surfel_areas_cum_gt, percent / 100.0)
        perc_distance_gt_to_pred = distances_gt_to_pred[min(idx, len(distances_gt_to_pred) - 1)]
    else:
        # np.Inf was removed in NumPy 2.0; np.inf is the portable spelling.
        perc_distance_gt_to_pred = np.inf

    if len(distances_pred_to_gt) > 0:
        surfel_areas_cum_pred = np.cumsum(surfel_areas_pred) / np.sum(surfel_areas_pred)
        idx = np.searchsorted(surfel_areas_cum_pred, percent / 100.0)
        perc_distance_pred_to_gt = distances_pred_to_gt[min(idx, len(distances_pred_to_gt) - 1)]
    else:
        perc_distance_pred_to_gt = np.inf

    # Symmetric Hausdorff: worst of the two directed percentile distances.
    return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt)
|
| 452 |
+
|
| 453 |
+
def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
    """Return the fraction of each surface lying within `tolerance_mm` of the other.

    Args:
      surface_distances: dict as produced by `compute_surface_distances`.
      tolerance_mm: tolerance in mm.

    Returns:
      Tuple (overlap fraction of gt surface, overlap fraction of pred surface),
      each an area-weighted fraction in [0, 1].
    """
    def _overlap_fraction(distances, areas):
        # Area within tolerance divided by total surface area.
        return np.sum(areas[distances <= tolerance_mm]) / np.sum(areas)

    rel_overlap_gt = _overlap_fraction(surface_distances["distances_gt_to_pred"],
                                       surface_distances["surfel_areas_gt"])
    rel_overlap_pred = _overlap_fraction(surface_distances["distances_pred_to_gt"],
                                         surface_distances["surfel_areas_pred"])
    return (rel_overlap_gt, rel_overlap_pred)
|
| 461 |
+
|
| 462 |
+
def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
    """Return the surface Dice (NSD) at the given tolerance.

    Args:
      surface_distances: dict as produced by `compute_surface_distances`.
      tolerance_mm: tolerance in mm.

    Returns:
      Surface Dice in [0, 1]: combined within-tolerance surface area of both
      surfaces, divided by the total surface area of both surfaces.
    """
    d_gt = surface_distances["distances_gt_to_pred"]
    d_pred = surface_distances["distances_pred_to_gt"]
    a_gt = surface_distances["surfel_areas_gt"]
    a_pred = surface_distances["surfel_areas_pred"]

    # Area of each surface that lies within tolerance of the other surface.
    within_gt = np.sum(a_gt[d_gt <= tolerance_mm])
    within_pred = np.sum(a_pred[d_pred <= tolerance_mm])
    total_area = np.sum(a_gt) + np.sum(a_pred)
    return (within_gt + within_pred) / total_area
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def compute_dice_coefficient(mask_gt, mask_pred):
    """Compute soerensen-dice coefficient.

    compute the soerensen-dice coefficient between the ground truth mask
    `mask_gt` and the predicted mask `mask_pred`.

    Args:
      mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
      mask_pred: 3-dim Numpy array of type bool. The predicted mask.

    Returns:
      the dice coeffcient as float. If both masks are empty, the result is NaN.
    """
    volume_sum = mask_gt.sum() + mask_pred.sum()
    if volume_sum == 0:
        # Dice is undefined for two empty masks.
        # np.NaN was removed in NumPy 2.0; np.nan is the portable spelling.
        return np.nan
    volume_intersect = (mask_gt & mask_pred).sum()
    return 2 * volume_intersect / volume_sum
|
| 492 |
+
|
evaluate/__init__.py
ADDED
|
File without changes
|
evaluate/evaluator.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.cuda.amp import autocast as autocast
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from einops import rearrange, repeat, reduce
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import nibabel as nib
|
| 12 |
+
import shutil
|
| 13 |
+
import pickle
|
| 14 |
+
from scipy.ndimage import gaussian_filter
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
|
| 17 |
+
from evaluate.metric import calculate_metric_percase
|
| 18 |
+
from evaluate.merge_after_evaluate import merge
|
| 19 |
+
from train.dist import is_master
|
| 20 |
+
|
| 21 |
+
def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_factor: float = 10, dtype=np.float16):
    """Build a Gaussian importance map for sliding-window patch blending.

    Places a unit impulse at the centre of a `tile_size` array, blurs it with
    a Gaussian whose sigma per axis is `axis_len * sigma_scale`, rescales the
    peak to `value_scaling_factor`, and guarantees all entries are strictly
    positive so later divisions by the accumulated map never hit zero.

    Args:
      tile_size: shape of the patch, e.g. (288, 288, 96).
      sigma_scale: per-axis sigma as a fraction of the axis length.
      value_scaling_factor: value of the map's peak after rescaling.
      dtype: numpy dtype of the returned array.

    Returns:
      Numpy array of shape `tile_size` and dtype `dtype`, strictly positive,
      peaking at the centre with value `value_scaling_factor`.
    """
    impulse = np.zeros(tile_size)
    centre = tuple(axis_len // 2 for axis_len in tile_size)
    impulse[centre] = 1

    sigmas = [axis_len * sigma_scale for axis_len in tile_size]
    importance = gaussian_filter(impulse, sigmas, 0, mode='constant', cval=0)

    # Normalise the peak to value_scaling_factor, then cast.
    importance = (importance / np.max(importance)) * value_scaling_factor
    importance = importance.astype(dtype)

    # Entries that underflowed to 0 would produce NaNs when dividing by the
    # accumulated map later — replace them with the smallest nonzero value.
    nonzero_min = np.min(importance[importance != 0])
    importance[importance == 0] = nonzero_min

    return importance
|
| 38 |
+
|
| 39 |
+
def evaluate(model,
             text_encoder,
             device,
             testset,
             testloader,
             dice_score,
             nsd_score,
             csv_path,
             resume,
             save_interval,
             visualization):
    """Run sliding-window evaluation over `testloader` and write metrics to CSV/XLSX.

    Each "batch" from `testloader` is one whole volume, pre-split into batches of
    patches. Patch predictions are blended with a Gaussian importance map,
    thresholded at 0.5, and scored per label via `calculate_metric_percase`.
    Every rank dumps its per-sample results to pickle files; the master rank
    waits for all ranks, merges, deduplicates, and writes a summary CSV plus a
    multi-sheet XLSX. Requires torch.distributed to be initialised
    (dist.get_rank()/get_world_size() are called unconditionally).

    Args:
        model: segmentation model; called as model(queries=..., image_input=..., train_mode=False).
        text_encoder: encodes (labels, modality) into query embeddings.
        device: CUDA device for patch inference.
        testset: dataset object; only testset.load_image(path) is used (for visualization).
        testloader: iterable of per-volume samples (see keys read below).
        dice_score: truthy if dice should appear in the detailed XLSX sheets.
        nsd_score: truthy if nsd should appear in the detailed XLSX sheets.
        csv_path: output CSV path; also used to derive .pkl/.xlsx/tmp file names.
        resume: if True, master reloads '*_tmp_rank*' pickles from an interrupted run.
        save_interval: dump partial results every N samples per rank.
        visualization: if True, save image/gt/prediction as nii.gz next to the CSV.

    Returns:
        None (metrics are written to disk).
    """

    # Directory for nii.gz dumps, derived from the CSV path (only if visualizing).
    if visualization:
        nib_dir = csv_path.replace('.csv', '')

    # Collation containers only exist on the master process.
    if is_master():
        # datasets --> labels --> metrics
        datasets_labels_metrics = {} # {'COVID19':{'covid19_infection':{'dice':[0.8, 0.9, ...], ...} ...}, ...}

        # datasets --> samples --> labels --> metrics
        samples_labels_metrics = {} # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, ...} ...}, ...} records each sample (a row) within a dataset

        # datsets --> labels
        datasets_labels_sets = {} # {'COVID19':set('covid19_infection', ...), ...} records the label categories (columns) within a dataset

    # accumulate scores of each sample in each process
    results_of_samples = [] # each element : [dataset_name, modality, sample_id, scores_of_labels(dict), label_names]

    # load results from an interrupted eval (only in master process)
    if resume and is_master():
        root_dir = os.path.dirname(csv_path)
        prefix = os.path.basename(csv_path).replace('.csv', '_tmp_rank') # xxx/test/step_xxx.csv --> step_xxx_tmp_rank
        pkl_to_del = []
        for f in os.listdir(root_dir):
            if prefix in f:
                # load list of results
                pkl_path = f'{root_dir}/{f}'
                with open(pkl_path, 'rb') as f:
                    results_of_samples += pickle.load(f)
                print(f'Load results from {pkl_path}')
                pkl_to_del.append(pkl_path)

        # there may be duplication? We leave the deduplication to the final merge
        # merge all the loaded samples, del the tmp pickle files in previous evaluation task
        for pkl_path in pkl_to_del:
            os.remove(pkl_path)
            print(f'Del {pkl_path}')
        merge_pkl = csv_path.replace('.csv', f'_tmp_rank0.pkl')
        with open(merge_pkl, 'wb') as f:
            pickle.dump(results_of_samples, f)
        print(f'Load results of {len(results_of_samples)} samples, Merge into {merge_pkl}')

    model.eval()
    text_encoder.eval()

    with torch.no_grad():

        # Wall-clock accounting, averaged over volumes at the end.
        data_time = 0
        pred_time = 0
        metric_time = 0

        # NOTE(review): these two counters are never incremented anywhere in
        # this function, so the printed averages are always 0 — confirm intent.
        avg_patch_batch_num = 0
        avg_query_batch_num = 0

        # in ddp, only master process display the progress bar
        if is_master():
            testloader = tqdm(testloader, disable=False)
        else:
            testloader = tqdm(testloader, disable=True)

        # gaussian kernel to accumulate predcition
        # NOTE(review): kernel size is hard-coded to (288, 288, 96) — presumably
        # the sliding-window patch size; verify it matches the loader's patches.
        gaussian = torch.tensor(compute_gaussian((288, 288, 96))).to(device) # hwd

        end_time = time.time()
        for sample in testloader: # in evaluation/inference, a "batch" in loader is a volume
            # data loading
            dataset_name = sample['dataset_name']
            sample_id = sample['sample_id']
            batched_patches = sample['batched_patches']
            batched_y1y2_x1x2_z1z2 = sample['batched_y1y2_x1x2_z1z2']
            labels = sample['labels']
            gt_segmentation = sample['gt_segmentation'].numpy() # n h w d
            modality = sample['modality']
            image_path = sample['image_path']

            # Accumulators for Gaussian-weighted patch blending (on CPU).
            n,h,w,d = gt_segmentation.shape
            prediction = torch.zeros((n, h, w, d))
            accumulation = torch.zeros((n, h, w, d))

            data_time += (time.time()-end_time)
            end_time = time.time()

            with autocast():

                # Text queries are shared across all patches of this volume.
                queries = text_encoder(labels, modality)

                # for each batch of patches, query with all labels
                for patches, y1y2_x1x2_z1z2_ls in zip(batched_patches, batched_y1y2_x1x2_z1z2): # [b, c, h, w, d]
                    patches = patches.to(device=device)
                    prediction_patch = model(queries=queries, image_input=patches, train_mode=False)
                    prediction_patch = torch.sigmoid(prediction_patch) # bnhwd
                    prediction_patch = prediction_patch.detach() # .cpu().numpy()

                    # fill in
                    for b in range(len(y1y2_x1x2_z1z2_ls)):
                        y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls[b]

                        # gaussian accumulation; slices to :y2-y1 etc. handle
                        # edge patches smaller than the full kernel
                        tmp = prediction_patch[b, :, :y2-y1, :x2-x1, :z2-z1] * gaussian[:y2-y1, :x2-x1, :z2-z1] # on gpu
                        prediction[:, y1:y2, x1:x2, z1:z2] += tmp.cpu()
                        accumulation[:, y1:y2, x1:x2, z1:z2] += gaussian[:y2-y1, :x2-x1, :z2-z1].cpu()

            pred_time += (time.time()-end_time)
            end_time = time.time()

            # avg: normalise by the accumulated Gaussian weight, then binarise
            prediction = prediction / accumulation
            prediction = torch.where(prediction>0.5, 1.0, 0.0)
            prediction = prediction.numpy()

            # cal metrics : [{'dice':x, ...}, ...]
            scores = []
            for j in range(len(labels)):
                scores.append(calculate_metric_percase(prediction[j, :, :, :], gt_segmentation[j, :, :, :], dice_score, nsd_score)) # {'dice':0.9, 'nsd':0.8} one dict per label

            # visualization
            if visualization:
                Path(f'{nib_dir}/{dataset_name}').mkdir(exist_ok=True, parents=True)
                # save image, gt and prediction as nii.gz
                results = np.zeros((h, w, d)) # hwd
                for j, label in enumerate(labels):
                    results += prediction[j, :, :, :] * (j+1) # 0 --> 1 (skip background)
                    Path(f'{nib_dir}/{dataset_name}/seg_{sample_id}').mkdir(exist_ok=True, parents=True)
                    # one nii.gz per label
                    segobj = nib.nifti2.Nifti1Image(prediction[j, :, :, :], np.eye(4))
                    nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}/{label}.nii.gz')
                segobj = nib.nifti2.Nifti1Image(results, np.eye(4))
                nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}.nii.gz')

                image = testset.load_image(image_path)
                image = np.squeeze(image)
                imgobj = nib.nifti2.Nifti1Image(image, np.eye(4))
                nib.save(imgobj, f'{nib_dir}/{dataset_name}/img_{sample_id}.nii.gz')

                gt = np.zeros((h, w, d)) # hwd
                for j, label in enumerate(labels):
                    gt += gt_segmentation[j, :, :, :] * (j+1) # 0 --> 1 (skip background)
                    Path(f'{nib_dir}/{dataset_name}/gt_{sample_id}').mkdir(exist_ok=True, parents=True)
                    # one nii.gz per label
                    segobj = nib.nifti2.Nifti1Image(gt_segmentation[j, :, :, :], np.eye(4))
                    nib.save(segobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}/{label}.nii.gz')
                gtobj = nib.nifti2.Nifti1Image(gt, np.eye(4))
                nib.save(gtobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}.nii.gz')

            metric_time += (time.time()-end_time)
            end_time = time.time()

            # accumulate
            results_of_samples.append([dataset_name, modality, sample_id, scores, labels])

            # save in each process regularly in case of interruption
            if len(results_of_samples) % save_interval == 0:
                with open(csv_path.replace('.csv', f'_tmp_rank{dist.get_rank()}.pkl'), 'wb') as f:
                    pickle.dump(results_of_samples, f)

        """
        # gather results from all device to rank-0 (solution 1)
        gather_results = [None for i in range(dist.get_world_size())]
        dist.gather_object(
            results_of_samples,
            gather_results if dist.get_rank() == 0 else None,
            dst = 0
        )

        if int(dist.get_rank()) == 0:
            results_of_samples = [tmp for ls in results_of_samples for tmp in ls]
        """

        avg_patch_batch_num /= len(testloader)
        avg_query_batch_num /= len(testloader)
        data_time /= len(testloader)
        pred_time /= len(testloader)
        metric_time /= len(testloader)
        print(f'On Rank {dist.get_rank()}, each sample has {avg_patch_batch_num} batch of patches and {avg_query_batch_num} batch of queries, Data Time: {data_time}, Pred Time: {pred_time}, Dice Time: {metric_time}')

        torch.cuda.empty_cache()

    # save in each process (to a fnl pickle, also denoting this process ends)
    with open(csv_path.replace('.csv', f'_fnl_rank{dist.get_rank()}.pkl'), 'wb') as f:
        pickle.dump(results_of_samples, f)

    # gather and record in rank 0 (solution 2): ranks synchronise through the
    # file system — the '_fnl_rank*' pickle doubles as a completion flag.
    if is_master():

        # detect the finish of each process
        while True:
            all_process_finished = True
            for rank_id in range(torch.distributed.get_world_size()):
                if not os.path.exists(csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl')): # xxx_tmp_rankx.pkl
                    all_process_finished = False
                    break
            if all_process_finished:
                break
            else:
                time.sleep(10)

        # read results of each process (samples may be duplicated due to the even distribution of ddp, check)
        results_of_samples = []
        for rank_id in range(torch.distributed.get_world_size()):
            fnl_results_file = csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl')
            tmp_results_file = csv_path.replace('.csv', f'_tmp_rank{rank_id}.pkl')
            with open(fnl_results_file, 'rb') as f:
                results_of_samples += pickle.load(f)
            os.remove(fnl_results_file)
            if os.path.exists(tmp_results_file):
                os.remove(tmp_results_file)

        # check duplication: keep the first occurrence of each dataset/sample pair
        unique_set = set()
        deduplicated_results_of_samples = []
        for dataset_name, modality, sample_id, scores, labels in results_of_samples:
            if f'{dataset_name}/{sample_id}' not in unique_set:
                unique_set.add(f'{dataset_name}/{sample_id}')
                deduplicated_results_of_samples.append([dataset_name, modality, sample_id, scores, labels])
        results_of_samples = deduplicated_results_of_samples

        # save for tmp
        with open(csv_path.replace('.csv', '.pkl'), 'wb') as f:
            pickle.dump(results_of_samples, f)

        # collate results
        for dataset_name, modality, sample_id, scores, labels in results_of_samples: # [[dataset_name, modality, sample_id, scores_of_labels(dict), label_names], ...]
            dataset_name = f'{dataset_name}({modality})'

            if dataset_name not in datasets_labels_metrics:
                datasets_labels_metrics[dataset_name] = {} # {'COVID19(CT)':{}}
            if dataset_name not in datasets_labels_sets:
                datasets_labels_sets[dataset_name] = set() # {'COVID19(CT)':set()}
            if dataset_name not in samples_labels_metrics:
                samples_labels_metrics[dataset_name] = {}
            samples_labels_metrics[dataset_name][sample_id] = {} # {'COVID19(CT)':{'0':{}}}

            for metric_dict, label in zip(scores, labels):
                # accumulate metrics (for per dataset per class
                # {'COVID19(CT)':{'covid19_infection':{'dice':[0.8, 0.9, ...], 'nsd':[0.8, 0.9, ...], ...} ...}, ...}
                if label not in datasets_labels_metrics[dataset_name]:
                    datasets_labels_metrics[dataset_name][label] = {k:[v] for k,v in metric_dict.items()}
                else:
                    for k,v in metric_dict.items():
                        datasets_labels_metrics[dataset_name][label][k].append(v)

                # statistic labels
                # {'COVID19(CT)':set('covid19_infection', ...)}
                if label not in datasets_labels_sets[dataset_name]:
                    datasets_labels_sets[dataset_name].add(label)

                # record metrics (for per dataset per sample per class
                # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, 'nsd':0.9, ...} ...}, ...}
                samples_labels_metrics[dataset_name][sample_id][label] = {k:v for k,v in metric_dict.items()}

        # average and log (columns are metrics, e.g. dice, nsd, ...)
        # create a df like:
        # {
        #   'TotalSegmentator': [0.xx, 0.xx, ...]  # dataset-level summary row
        #   'TotalSegmentator, Lung': [0.68, 0.72, ...]
        # }
        # by defult, print the dice (1st metric) of each dataset
        # NOTE(review): `metric_dict` below leaks out of the collate loop above
        # (last iteration's value) — breaks if results_of_samples is empty; confirm.
        info = 'Metrics of Each Dataset:\n'
        avg_df = {}
        for dataset in datasets_labels_metrics.keys():
            avg_df[dataset] = {k:[] for k in metric_dict.keys()} # 'TotalSegmentator(CT)': {'dice':[0.8, ...] 'nsd':[0.5, ...], ...}
            for label in datasets_labels_metrics[dataset].keys():
                avg_df[f'{dataset}, {label}'] = []
                for metric in datasets_labels_metrics[dataset][label].keys():
                    label_metric = np.average(datasets_labels_metrics[dataset][label][metric])
                    avg_df[f'{dataset}, {label}'].append(label_metric) # 'TotalSegmentator, Lung': [0.68, 0.72, ...] list of num_metrics
                    avg_df[dataset][metric].append(label_metric)
            avg_df[dataset] = {k:np.average(v) for k,v in avg_df[dataset].items()} # 'TotalSegmentator': {'dice':[0.8, ...] 'nsd':[0.5, ...], ...} --> 'TotalSegmentator': {'dice':0.x, 'nsd':0.x, ...}
            info += f'{dataset} | '
            for k ,v in avg_df[dataset].items():
                info += f'{v}({k}) | '
            info += '\n'
            avg_df[dataset] = list(avg_df[dataset].values())
        avg_df = pd.DataFrame(avg_df).T
        avg_df.columns = list(metric_dict.keys()) # ['dice', 'nsd']
        avg_df.to_csv(csv_path)
        print(info)

        # detailed log (nsd and dice, columns are class labels)
        # multi-sheet, two for each dataset
        df_list = [['summary', avg_df]]
        for dataset, label_set in datasets_labels_sets.items():
            metric_df ={}
            if dice_score:
                metric_df['dice'] = {}
            if nsd_score:
                metric_df['nsd'] = {}

            # create dfs like:
            # {
            #   '0.npy': [0.xx, 0.xx, ...]
            #   ......
            # }

            # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, ...} ...}, ...}
            for image_id, label_dict in samples_labels_metrics[dataset].items():
                for metric in metric_df:
                    tmp = [] # one dice for each label in this dataset
                    for label in label_set:
                        # -1 marks a label absent from this sample
                        score = label_dict[label][metric] if label in label_dict else -1
                        tmp.append(score)
                    metric_df[metric][image_id] = tmp

            # NOTE(review): the loop variable rebinds `metric_df` while iterating
            # its own .items() view — works, but fragile; consider renaming.
            for metric, metric_df in metric_df.items():
                metric_df = pd.DataFrame(metric_df).T
                metric_df.columns = list(label_set)
                df_list.append([dataset+f'({metric})', metric_df])

        xlsx_path = csv_path.replace('.csv', '.xlsx')
        with pd.ExcelWriter(xlsx_path) as writer:
            for name, df in df_list:
                # write each DataFrame to its own sheet (sheet name must be < 31)
                if len(name) > 31:
                    name = name[len(name)-31:]
                df.to_excel(writer, sheet_name=name, index=True)

        # avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)

        os.remove(csv_path.replace('.csv', '.pkl'))

    else:

        pass

        # avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0

    return # avg_dice_over_merged_labels, avg_nsd_over_merged_labels
|
| 378 |
+
|
| 379 |
+
|
evaluate/merge_after_evaluate.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import openpyxl
|
| 5 |
+
|
| 6 |
+
def merge(mod_label_json, mod_label_statistic, xlsx2load, xlsx2save):
    """Aggregate per-(dataset, label) Dice/NSD results into merged summary sheets.

    Reads sheet 0 of ``xlsx2load``, whose first column holds either a
    dataset-level name (no comma) or a ``"Dataset(modality), label"`` pair,
    with Dice in column 2 and (optionally) NSD in column 3. Three sheets are
    appended to the workbook, which is then saved to ``xlsx2save``:

    - ``Dataset Merge``: the dataset-level rows copied verbatim.
    - ``Label Merge``: scores averaged per unique ``modality_label`` id,
      annotated with annotation counts from ``mod_label_statistic``.
    - ``Region Merge``: scores averaged per anatomical region, using the
      region-to-label mapping from ``mod_label_json``.

    Args:
        mod_label_json: path to a JSON file with ``region_based`` (region ->
            list of ``modality_label`` ids) and ``abnormal`` entries.
        mod_label_statistic: path to a JSON file mapping ``modality_label`` ->
            ``[_, total_num, aug_ratio]``.
        xlsx2load: Excel file holding the per-(dataset, label) results.
        xlsx2save: destination path for the augmented workbook.

    Returns:
        Tuple ``(avg_dice, avg_nsd)`` averaged over all merged labels.
    """
    mod_lab2dice = {}

    # Load the first sheet of the Excel file.
    df = pd.read_excel(xlsx2load, sheet_name=0)
    # A third column means NSD scores accompany the Dice scores.
    has_nsd = len(df.columns) > 2

    # Copy the dataset-level rows into a new 'Dataset Merge' sheet.
    workbook = openpyxl.load_workbook(xlsx2load)
    new_sheet = workbook.create_sheet(title='Dataset Merge', index=1)
    new_sheet.cell(row=1, column=1, value='Dataset')
    new_sheet.cell(row=1, column=2, value='Dice')
    new_sheet.cell(row=1, column=3, value='NSD')
    row = 2
    for i in range(len(df)):
        # Rows without a comma are dataset-level summaries, not per-label rows.
        if ',' not in df.iloc[i, 0]:
            new_sheet.cell(row=row, column=1, value=df.iloc[i, 0])
            new_sheet.cell(row=row, column=2, value=df.iloc[i, 1])
            if has_nsd:
                new_sheet.cell(row=row, column=3, value=df.iloc[i, 2])
            row += 1

    # First columns of the results sheet.
    dataset_label_ls = df.iloc[:, 0]
    dice_ls = df.iloc[:, 1]
    nsd_ls = df.iloc[:, 2] if has_nsd else [0] * len(df)

    for dataset_modality_label, dice, nsd in zip(dataset_label_ls, dice_ls, nsd_ls):  # e.g. "MSD_Pancreas(ct), pancreas", 0.89
        if ', ' not in dataset_modality_label:
            continue
        dataset_modality, label = dataset_modality_label.split(', ')
        label = label.lower()  # pancreas
        modality = dataset_modality.split('(')[-1].split(')')[0]  # ct

        # Unique id: modality_label
        mod_lab = f'{modality}_{label}'

        # Accumulate dice/nsd and record where each score comes from
        # (dataset, label, modality).
        if mod_lab not in mod_lab2dice:
            mod_lab2dice[mod_lab] = {'dice': [], 'nsd': [], 'merge': []}
        mod_lab2dice[mod_lab]['dice'].append(dice)
        mod_lab2dice[mod_lab]['nsd'].append(nsd)
        mod_lab2dice[mod_lab]['merge'].append(dataset_modality_label)

    # Retrieve the region-to-label mapping.
    with open(mod_label_json, 'r') as f:
        region_config = json.load(f)
    region2label = region_config['region_based']
    for region, label_ls in region2label.items():
        region2label[region] = [mod_lab.split('_')[-1] for mod_lab in label_ls]  # strip the modality prefix
    region2label['abnormal'] = [mod_lab.split('_')[-1] for mod_lab in region_config['abnormal']]

    region_dice_ls = {k: [] for k in region2label.keys()}   # {'brain': [0.9, ...], ...}
    region_nsd_ls = {k: [] for k in region2label.keys()}    # {'brain': [0.9, ...], ...}
    region_merge_ls = {k: [] for k in region2label.keys()}  # {'brain': ['frontal lobe', ...], ...}

    mod_lab_ls = []
    dice_ls = []
    nsd_ls = []
    merge_ls = []
    region_ls = []
    for mod_lab, accum in mod_lab2dice.items():
        label = mod_lab.split('_')[-1]
        mod_lab_ls.append(mod_lab)
        dice_ls.append(sum(accum['dice']) / len(accum['dice']))
        nsd_ls.append(sum(accum['nsd']) / len(accum['nsd']))
        merge_ls.append(' / '.join(accum['merge']))

        # Assign the label to a region; 'abnormal' takes priority.
        if label in region2label['abnormal']:
            region_dice_ls['abnormal'].append(dice_ls[-1])
            region_nsd_ls['abnormal'].append(nsd_ls[-1])
            region_merge_ls['abnormal'].append(mod_lab)
            region_ls.append('abnormal')
        else:
            found = False
            for region, labels_in_region in region2label.items():
                if label in labels_in_region:
                    region_dice_ls[region].append(dice_ls[-1])
                    region_nsd_ls[region].append(nsd_ls[-1])
                    region_merge_ls[region].append(mod_lab)
                    region_ls.append(region)
                    found = True
                    break
            if not found:
                print(label)
                region_ls.append('unknown')

    # Annotation statistics (total count and augmentation ratio) per label.
    with open(mod_label_statistic, 'r') as f:
        statistic_dict = json.load(f)

    # Write the per-label averages into a new 'Label Merge' sheet.
    new_sheet = workbook.create_sheet(title='Label Merge', index=1)
    new_sheet.cell(row=1, column=1, value='Modality_Label')
    new_sheet.cell(row=1, column=2, value='Dice')
    new_sheet.cell(row=1, column=3, value='NSD')
    new_sheet.cell(row=1, column=4, value='Merge')
    new_sheet.cell(row=1, column=5, value='Region')
    new_sheet.cell(row=1, column=6, value='Total_Num')
    new_sheet.cell(row=1, column=7, value='Aug_Ratio')
    row = 2
    for mod_lab, dice, nsd, merge, region in zip(mod_lab_ls, dice_ls, nsd_ls, merge_ls, region_ls):
        if mod_lab in statistic_dict:
            _, total_num, aug_ratio = statistic_dict[mod_lab]
        else:
            total_num = aug_ratio = 0
        new_sheet.cell(row=row, column=1, value=mod_lab)
        new_sheet.cell(row=row, column=2, value=dice)
        new_sheet.cell(row=row, column=3, value=nsd)
        new_sheet.cell(row=row, column=4, value=merge)
        new_sheet.cell(row=row, column=5, value=region)
        new_sheet.cell(row=row, column=6, value=total_num)
        new_sheet.cell(row=row, column=7, value=aug_ratio)
        row += 1
    new_sheet.cell(row=row, column=2, value=sum(dice_ls) / len(dice_ls))  # avg over all labels
    new_sheet.cell(row=row, column=3, value=sum(nsd_ls) / len(nsd_ls))

    # Write the per-region averages into a new 'Region Merge' sheet.
    new_sheet = workbook.create_sheet(title='Region Merge', index=1)
    new_sheet.cell(row=1, column=1, value='Region')
    new_sheet.cell(row=1, column=2, value='Dice')
    new_sheet.cell(row=1, column=3, value='NSD')
    new_sheet.cell(row=1, column=4, value='Merge')
    row = 2
    for key in region_dice_ls.keys():
        if len(region_dice_ls[key]) == 0:
            dice = nsd = 0
            merge = None
        else:
            dice = sum(region_dice_ls[key]) / len(region_dice_ls[key])
            nsd = sum(region_nsd_ls[key]) / len(region_nsd_ls[key])
            merge = ','.join(region_merge_ls[key])
        class_name = f'{key}({len(region_dice_ls[key])})'
        new_sheet.cell(row=row, column=1, value=class_name)
        new_sheet.cell(row=row, column=2, value=dice)
        new_sheet.cell(row=row, column=3, value=nsd)
        new_sheet.cell(row=row, column=4, value=merge)
        row += 1

    workbook.save(xlsx2save)

    # Averages over all merged labels.
    avg_dice_over_merged_labels = sum(dice_ls) / len(dice_ls)
    avg_nsd_over_merged_labels = sum(nsd_ls) / len(nsd_ls)

    return avg_dice_over_merged_labels, avg_nsd_over_merged_labels
|
| 173 |
+
|
| 174 |
+
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--xlsx2load', type=str)
    parser.add_argument('--xlsx2save', type=str)
    parser.add_argument('--mod_lab_json', type=str, default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab(72).json')
    parser.add_argument('--mod_label_statistic', type=str, default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab_accum_statis(49).json')

    config = parser.parse_args()

    # When no explicit save path is given, augment the input workbook in place.
    if not config.xlsx2save:
        config.xlsx2save = config.xlsx2load

    merge(config.mod_lab_json, config.mod_label_statistic, config.xlsx2load, config.xlsx2save)
|
evaluate/metric.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import time
|
| 4 |
+
from medpy import metric
|
| 5 |
+
from .SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance
|
| 6 |
+
|
| 7 |
+
def calculate_metric_percase(pred, gt, dice=True, nsd=True, spacing=(1, 1, 3), tolerance=1.0):
    """Compute Dice and/or NSD between a binary prediction and ground truth.

    Args:
        pred: predicted mask; any array castable to bool.
        gt: ground-truth mask with the same shape as ``pred``.
        dice: whether to compute the Dice coefficient.
        nsd: whether to compute the normalized surface distance (surface Dice).
        spacing: per-axis voxel spacing handed to the surface-distance
            computation (defaults match the previous hard-coded ``[1, 1, 3]``).
        tolerance: NSD tolerance (defaults match the previous hard-coded ``1``).

    Returns:
        Dict containing ``'dice'`` and/or ``'nsd'`` depending on the flags.
        When the ground truth is empty, both metrics score 1.0 for an exact
        (empty) prediction and 0.0 otherwise.
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)

    metrics = {}

    # Degenerate case: empty ground truth. The surface-distance code cannot
    # handle an empty reference mask, so score directly: perfect only when
    # the prediction is also empty.
    if np.sum(gt) == 0.0:
        score = 1.0 if np.sum(pred) == 0.0 else 0.0
        if dice:
            metrics['dice'] = score
        if nsd:
            metrics['nsd'] = score
        return metrics

    if dice:
        metrics['dice'] = metric.binary.dc(pred, gt)

    if nsd:
        surface_distances = compute_surface_distances(gt, pred, list(spacing))
        metrics['nsd'] = compute_surface_dice_at_tolerance(surface_distances, tolerance)

    return metrics
|
| 36 |
+
|
| 37 |
+
if __name__ == '__main__':
    # Smoke test: partially-overlapping boxes should give a Dice between 0 and 1.
    shape = (3, 256, 256, 16)
    pred = torch.zeros(shape).numpy()
    gt = torch.zeros(shape).numpy()
    pred[:, 0:128, 0:128, :] = 1.0
    gt[:, 0:64, 0:64, :] = 1.0
    dice = calculate_metric_percase(pred, gt)['dice']
    print(dice)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
evaluate/params.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
def str2bool(v):
    """Parse a boolean CLI flag value.

    An actual ``bool`` is returned unchanged (argparse may pass the default
    through this converter); otherwise any case-insensitive ``'true'``/``'t'``
    string maps to True, and every other value maps to False.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ('true', 't')
|
| 5 |
+
|
| 6 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for the evaluation script.

    Returns an ``argparse.Namespace`` covering experiment bookkeeping,
    metric selection, dataset/prompt inputs, sampling/loading options,
    the text (knowledge) encoder, and the MaskFormer vision model.
    Boolean flags are parsed with ``str2bool`` (only 'true'/'t' map to True).
    """
    parser = argparse.ArgumentParser()

    # Exp Controller — where to record results and how to resume.

    parser.add_argument(
        "--rcd_dir",
        type=str,
        help="save the evaluation results (in a directory)",
    )
    parser.add_argument(
        "--rcd_file",
        type=str,
        help="save the evaluation results (in a csv/xlsx file)",
    )
    parser.add_argument(
        "--visualization",
        type=str2bool,
        default=False,
        help="save the visualization for each case (img, gt, pred)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Checkpoint path",
    )
    parser.add_argument(
        "--partial_load",
        type=str2bool,
        default=True,
        help="Allow to load partial paramters from checkpoint",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--resume",
        type=str2bool,
        default=True,
        help="Inherit medial results from an interrupted evaluation (no harm even if you evaluate from scratch)",
    )
    # How often (in cases) intermediate results are flushed to disk.
    parser.add_argument(
        "--save_interval",
        type=int,
        default=100
    )

    # Metrics — which scores to compute per case.

    parser.add_argument(
        "--dice",
        type=str2bool,
        default=True,
    )
    parser.add_argument(
        "--nsd",
        type=str2bool,
        default=True,
    )

    # Med SAM Dataset — evaluation data and text prompts.

    parser.add_argument(
        "--datasets_jsonl",
        type=str,
    )
    parser.add_argument(
        "--text_prompts_json",
        type=str,
        help='This is needed for CVPR25 challenge, where multiple prompts (synonyms) are required.'
    )

    # Sampler and Loader — patch extraction and DataLoader settings.

    parser.add_argument(
        "--online_crop",
        type=str2bool,
        default='False',
        help='load pre-cropped image patches directly, or crop online',
    )
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs='+',
        default=[288, 288, 96],
    )
    # Maximum number of text queries fed to the model at once.
    parser.add_argument(
        "--max_queries",
        type=int,
        default=256,
    )
    parser.add_argument(
        "--batchsize_3d",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--pin_memory",
        type=str2bool,
        default=False,
        help='load data to gpu to accelerate'
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4
    )

    # Knowledge Encoder — text-encoder checkpoint and loading behavior.
    parser.add_argument(
        "--text_encoder_partial_load",
        type=str2bool,
        default=True,
        help="Allow to load partial paramters from checkpoint",
    )
    parser.add_argument(
        "--text_encoder_checkpoint",
        type=str,
    )
    parser.add_argument(
        "--text_encoder",
        type=str,
    )

    # MaskFormer — vision backbone architecture options.

    parser.add_argument(
        "--vision_backbone",
        type=str,
        help='UNET or UNET-H'
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        nargs='+',
        default=[32, 32, 32],
        help='patch size on h w and d'
    )
    parser.add_argument(
        "--deep_supervision",
        type=str2bool,
        default=False,
    )

    args = parser.parse_args()
    return args
|
inference_medals_nifti.py
ADDED
|
@@ -0,0 +1,1885 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Medal-S inference script for generic raw image segmentation.
|
| 3 |
+
|
| 4 |
+
This script provides an interface for running Medal-S inference
|
| 5 |
+
on raw NIfTI images. It supports both single-stage (Stage 2 only) and
|
| 6 |
+
two-stage (Stage 1 + Stage 2) inference modes.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python inference_medals.py --input input.nii.gz --output output.nii.gz \\
|
| 10 |
+
--modality CT --texts "Aorta observed in abdominal CT scans" --labels 1
|
| 11 |
+
|
| 12 |
+
# Or use JSON configuration file:
|
| 13 |
+
python inference_medals.py --input input.nii.gz --output output.nii.gz \\
|
| 14 |
+
--config config.json --mode stage1+stage2
|
| 15 |
+
|
| 16 |
+
Author: Pengcheng Shi
|
| 17 |
+
Institute: Medical Image Insights, Inc., Shanghai, China
|
| 18 |
+
Email: shipc1220@gmail.com
|
| 19 |
+
License: Apache License 2.0
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import time
|
| 26 |
+
import math
|
| 27 |
+
import random
|
| 28 |
+
import itertools
|
| 29 |
+
import gc
|
| 30 |
+
import numpy as np
|
| 31 |
+
import SimpleITK as sitk
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
from typing import List
|
| 35 |
+
from scipy.ndimage import label, gaussian_filter
|
| 36 |
+
from einops import rearrange
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
from torch.cuda.amp import autocast
|
| 39 |
+
|
| 40 |
+
from data.default_resampling import resample_data_or_seg, compute_new_shape, resample_data_or_seg_to_spacing
|
| 41 |
+
from data.resample_torch import resample_torch_fornnunet, resample_torch_simple
|
| 42 |
+
from model.maskformer import Maskformer
|
| 43 |
+
from model.knowledge_encoder import Knowledge_Encoder
|
| 44 |
+
|
| 45 |
+
def adjust_spacing(img_array, img_spacing):
    """
    Reorder spacing so the largest spacing sits on the smallest image axis.

    If the axis with the smallest extent is not the axis with the largest
    spacing, and that largest spacing exceeds 0.5, the two spacing entries
    are swapped; otherwise the spacing is returned as-is.

    Args:
        img_array: Image array (only its shape is inspected)
        img_spacing: Per-axis spacing values

    Returns:
        Spacing as a numpy array, possibly with two entries exchanged
    """
    spacing = np.asarray(img_spacing)
    thin_axis = np.argmin(img_array.shape)
    coarse_axis = np.argmax(spacing)

    if (thin_axis != coarse_axis) and (spacing[coarse_axis] > 0.5):
        # Exchange the two entries via an index permutation.
        order = np.arange(len(spacing))
        order[thin_axis], order[coarse_axis] = coarse_axis, thin_axis
        spacing = spacing[order]

    return spacing
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def remove_small_objects_binary(binary_data, min_size=10):
    """
    Drop connected components smaller than ``min_size`` voxels.

    Args:
        binary_data: Binary array
        min_size: Components with fewer voxels than this are discarded

    Returns:
        Boolean array keeping only the sufficiently large components
    """
    labeled, _ = label(binary_data)
    component_sizes = np.bincount(labeled.ravel())
    # Labels of the components big enough to survive; 0 is background
    # and must never be treated as a component.
    keep = np.flatnonzero(component_sizes >= min_size)
    keep = keep[keep != 0]
    return np.isin(labeled, keep)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def respace_image(image: np.ndarray, current_spacing: np.ndarray, target_spacing: np.ndarray, device: torch.device) -> np.ndarray:
    """
    Resample an image volume to a new voxel spacing.

    Args:
        image: Input image array with shape (C, H, W, D)
        current_spacing: Spacing the image currently has
        target_spacing: Spacing to resample to
        device: Torch device used by the resampling backend

    Returns:
        The image resampled to the spacing-derived target shape
    """
    # Spatial target shape is derived from the spacing ratio (channel dim excluded).
    target_shape = compute_new_shape(image.shape[1:], current_spacing, target_spacing)
    return resample_torch_fornnunet(
        image,
        target_shape,
        current_spacing,
        target_spacing,
        is_seg=False,  # intensity data: continuous interpolation
        num_threads=8,
        device=device,
        memefficient_seg_resampling=False,
        force_separate_z=None,
        separate_z_anisotropy_threshold=3.0,
    )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def respace_mask(mask: np.ndarray, current_spacing: np.ndarray, target_spacing: np.ndarray, device: torch.device) -> np.ndarray:
    """
    Resample a segmentation mask from its current voxel spacing to a target spacing.

    Same pipeline as ``respace_image`` but with label-preserving interpolation
    (is_seg=True).

    Args:
        mask: Mask array of shape (C, H, W, D).
        current_spacing: Spacing of the input, per spatial axis.
        target_spacing: Desired spacing, per spatial axis.
        device: Device on which the torch-based resampling runs.

    Returns:
        Resampled mask array.
    """
    target_shape = compute_new_shape(mask.shape[1:], current_spacing, target_spacing)
    return resample_torch_fornnunet(
        mask,
        target_shape,
        current_spacing,
        target_spacing,
        is_seg=True,
        num_threads=8,
        device=device,
        memefficient_seg_resampling=False,
        force_separate_z=None,
        separate_z_anisotropy_threshold=3.0,
    )
|
| 138 |
+
def split_3d(image_tensor, crop_size=[288, 288, 96]):
    """
    Split a 3D volume into overlapping patches.

    Patches are taken with a stride of half the crop size (50% overlap) so the
    whole volume is covered; patches that would overrun a border are shifted
    back so they end exactly at the border.

    Args:
        image_tensor: Input tensor of shape (C, H, W, D).
        crop_size: Patch extent per spatial axis [h, w, d].

    Returns:
        split_patch: List of patch tensors.
        split_idx: List of patch bounds [h_s, h_e, w_s, w_e, d_s, d_e].
    """
    def _axis_bounds(dim_len, patch_len, stride):
        """(start, end) pairs along one axis, clamped to the volume."""
        n = max(math.ceil(dim_len / stride) - 1, 1)
        bounds = []
        for idx in range(n):
            start = idx * stride
            end = start + patch_len
            if end > dim_len:
                # Shift the last patch back so it ends at the border.
                start = dim_len - patch_len
                end = dim_len
            if start < 0:
                # Volume smaller than the patch: take what exists.
                start = 0
            bounds.append((start, end))
        return bounds

    _, h, w, d = image_tensor.shape
    stride_h, stride_w, stride_d = crop_size[0] // 2, crop_size[1] // 2, crop_size[2] // 2

    split_idx = []
    split_patch = []
    for h_s, h_e in _axis_bounds(h, crop_size[0], stride_h):
        for w_s, w_e in _axis_bounds(w, crop_size[1], stride_w):
            for d_s, d_e in _axis_bounds(d, crop_size[2], stride_d):
                split_idx.append([h_s, h_e, w_s, w_e, d_s, d_e])
                split_patch.append(image_tensor[:, h_s:h_e, w_s:w_e, d_s:d_e])

    return split_patch, split_idx
|
| 192 |
+
def pad_if_necessary(image, crop_size=[288, 288, 96]):
    """
    Zero-pad a volume at the high end of each axis up to the crop size.

    Args:
        image: Input tensor of shape (C, H, W, D).
        crop_size: Minimum extent per spatial axis [h, w, d].

    Returns:
        padded_image: The (possibly) padded tensor; the original object is
            returned unchanged when no padding is needed.
        padding_info: Tuple (pad_h, pad_w, pad_d) of voxels added per axis.
    """
    _, h, w, d = image.shape
    pad_h = max(crop_size[0] - h, 0)
    pad_w = max(crop_size[1] - w, 0)
    pad_d = max(crop_size[2] - d, 0)
    padding_info = (pad_h, pad_w, pad_d)

    if any(padding_info):
        # F.pad takes pairs from the last dim backwards: (d_lo, d_hi, w_lo, w_hi, h_lo, h_hi);
        # we only pad the trailing side of each axis.
        image = F.pad(image, (0, pad_d, 0, pad_w, 0, pad_h), 'constant', 0)

    return image, padding_info
|
| 219 |
+
def remove_padding(padded_image, padding_info):
    """
    Remove trailing-edge padding previously added by ``pad_if_necessary``.

    The original implementation dispatched on ``isinstance(padded_image,
    torch.Tensor)``, but both branches were byte-identical — basic slicing
    has the same syntax and semantics for torch tensors and numpy arrays,
    so the type dispatch was dead code and has been removed.

    Args:
        padded_image: Padded array (torch.Tensor or numpy.ndarray), either
            4D (C, H, W, D) or 3D (H, W, D).
        padding_info: Tuple (pad_h, pad_w, pad_d) of voxels to strip from the
            end of each spatial axis.

    Returns:
        View of the input with the padding sliced off (a pad of 0 leaves the
        corresponding axis untouched, since ``[:n - 0]`` is the full axis).
    """
    pad_in_h, pad_in_w, pad_in_d = padding_info

    if len(padded_image.shape) == 4:
        # Channel-first volume: spatial axes are 1, 2, 3.
        return padded_image[:,
                            :padded_image.shape[1] - pad_in_h,
                            :padded_image.shape[2] - pad_in_w,
                            :padded_image.shape[3] - pad_in_d]
    # Plain 3D volume: spatial axes are 0, 1, 2.
    return padded_image[:padded_image.shape[0] - pad_in_h,
                        :padded_image.shape[1] - pad_in_w,
                        :padded_image.shape[2] - pad_in_d]
|
| 244 |
+
def internal_maybe_mirror_and_predict(model=None, queries=None, image_input=None, simulated_lowres_sc_pred=None,
                                      simulated_lowres_mc_pred=None, mirror_axes=(0, 1, 2)):
    """
    Test-time augmentation by mirroring.

    Runs the model once on the unmodified input and once per non-empty
    combination of mirrored spatial axes, un-flips each mirrored output, and
    averages everything.

    Args:
        model: Callable prediction model.
        queries: Query tensor forwarded to the model.
        image_input: Input image tensor (batch, channel, spatial...).
        simulated_lowres_sc_pred: Optional low-res single-channel prompt.
        simulated_lowres_mc_pred: Optional low-res multi-channel prompt.
        mirror_axes: Spatial axes to mirror (0-based over the spatial dims),
            or None to disable mirroring entirely.

    Returns:
        Averaged prediction tensor.
    """
    def _flip_prompt(prompt, axes):
        """Flip an optional prompt tensor along *axes*; None passes through."""
        if prompt is None:
            return None
        return torch.flip(prompt.unsqueeze(0), axes).squeeze(0)

    prediction = model(queries=queries,
                       image_input=image_input,
                       simulated_lowres_sc_pred=simulated_lowres_sc_pred,
                       simulated_lowres_mc_pred=simulated_lowres_mc_pred,
                       train_mode=False)

    if mirror_axes is None:
        return prediction

    assert max(mirror_axes) <= image_input.ndim - 3, 'mirror_axes does not match the dimension of the input!'
    # Offset by 2 to skip the batch and channel dimensions.
    spatial_axes = [m + 2 for m in mirror_axes]
    axes_combinations = [combo
                         for r in range(1, len(spatial_axes) + 1)
                         for combo in itertools.combinations(spatial_axes, r)]

    for axes in axes_combinations:
        flipped_prediction = model(queries=queries,
                                   image_input=torch.flip(image_input, axes),
                                   simulated_lowres_sc_pred=_flip_prompt(simulated_lowres_sc_pred, axes),
                                   simulated_lowres_mc_pred=_flip_prompt(simulated_lowres_mc_pred, axes),
                                   train_mode=False)
        # Un-flip before accumulating so all terms share the same orientation.
        prediction += torch.flip(flipped_prediction, axes)

    prediction /= (len(axes_combinations) + 1)
    return prediction
|
| 289 |
+
def compute_patch_prediction(
    queries: torch.Tensor,
    patches: torch.Tensor,
    lowres_single_channel_pred: torch.Tensor,
    lowres_multi_channel_pred: torch.Tensor,
    model: torch.nn.Module,
    possible_block_sizes: List[int],
    n_repeats: int = 1,
    disable_tta: bool = True
) -> torch.Tensor:
    """
    Patch prediction via complementary block masking.

    Each repeat: pick a block size, randomly select half of the coarse blocks,
    run the model once with the low-res prompts masked to the selected blocks
    and once with the complementary blocks, then stitch the two outputs so
    every voxel's prediction comes from the pass in which its prompt was
    hidden. Results are averaged over ``n_repeats``.

    Args:
        queries: Query tensor, shape (batch, query_dim).
        patches: Image patch tensor, shape (batch, channels, h, w, d).
        lowres_single_channel_pred: Low-res single-channel prompt, (1, 1, h, w, d).
        lowres_multi_channel_pred: Low-res multi-channel prompt, (1, c, h, w, d).
        model: Trained network.
        possible_block_sizes: Candidate coarse block sizes (e.g. [8, 16, 32]).
        n_repeats: Number of random-mask repeats to average.
        disable_tta: Skip mirrored test-time augmentation when True.

    Returns:
        Averaged prediction, shape (1, c, h, w, d).

    Raises:
        ValueError: If ``possible_block_sizes`` is empty or ``n_repeats`` < 1.
    """
    if not possible_block_sizes:
        raise ValueError("possible_block_sizes cannot be empty")
    if n_repeats < 1:
        raise ValueError("n_repeats must be at least 1")

    _, _, h, w, d = lowres_single_channel_pred.shape
    device = lowres_single_channel_pred.device
    accumulated = torch.zeros_like(lowres_multi_channel_pred, device=device)

    def _expand_mask(coarse_mask: torch.Tensor, block_size: int) -> torch.Tensor:
        """Nearest-neighbour upsample of a boolean block grid to voxel resolution."""
        full = (
            coarse_mask.unsqueeze(0).unsqueeze(0)
            .repeat_interleave(block_size, dim=2)
            .repeat_interleave(block_size, dim=3)
            .repeat_interleave(block_size, dim=4)
            [:, :, :h, :w, :d]  # trim the overhang when dims are not multiples
        )
        return full.float()

    def _predict(sc_prompt: torch.Tensor, mc_prompt: torch.Tensor) -> torch.Tensor:
        """One forward pass with the given prompts, with or without mirrored TTA."""
        if disable_tta:
            return model(
                queries=queries,
                image_input=patches,
                simulated_lowres_sc_pred=sc_prompt,
                simulated_lowres_mc_pred=mc_prompt,
                train_mode=False
            )
        return internal_maybe_mirror_and_predict(
            model=model,
            queries=queries,
            image_input=patches,
            simulated_lowres_sc_pred=sc_prompt,
            simulated_lowres_mc_pred=mc_prompt,
            mirror_axes=(0, 1, 2)
        )

    for _ in range(n_repeats):
        block_size = random.choice(possible_block_sizes)
        # Coarse grid dimensions (ceiling division so borders are covered).
        n_blocks_h = (h + block_size - 1) // block_size
        n_blocks_w = (w + block_size - 1) // block_size
        n_blocks_d = (d + block_size - 1) // block_size
        total_blocks = n_blocks_h * n_blocks_w * n_blocks_d

        # Randomly select roughly half the blocks (at least one).
        coarse = torch.zeros(n_blocks_h, n_blocks_w, n_blocks_d, dtype=torch.bool, device=device)
        chosen = torch.randperm(total_blocks, device=device)[:max(1, total_blocks // 2)]
        coarse.view(-1)[chosen] = True

        mask = _expand_mask(coarse, block_size)
        inverse_mask = 1.0 - mask

        first_half_pred = _predict(lowres_single_channel_pred * mask,
                                   lowres_multi_channel_pred * mask)
        second_half_pred = _predict(lowres_single_channel_pred * inverse_mask,
                                    lowres_multi_channel_pred * inverse_mask)

        # Keep each voxel's output from the pass where its prompt was hidden.
        accumulated += first_half_pred * inverse_mask + second_half_pred * mask

    return accumulated / n_repeats
|
| 403 |
+
def read_npz_data(raw_image, raw_spacing, crop_size=[288, 288, 96],
                  target_spacing=[1.5, 1.5, 3.0], scaled_roi_lowres_pred_array=None,
                  class_name_list=[], stage_1_flag=False, device=torch.device("cuda", 0), verbose=True):
    """
    Read and preprocess image data for inference.

    Handles spacing adjustments, image resampling, padding, and patch
    splitting for the inference pipeline.

    Fix: this function reassigns ``target_spacing[i]`` below. Previously it
    mutated the caller's list in place — corrupting e.g. a shared
    ``config['target_spacing']`` and the mutable default list across calls —
    so we now work on a private copy.

    Args:
        raw_image: Input image array with shape (d, h, w)
        raw_spacing: Spacing array with shape (3,)
        crop_size: Target crop size [h, w, d]
        target_spacing: Target spacing [h, w, d] (not mutated)
        scaled_roi_lowres_pred_array: Optional low-res prediction for ROI-based inference
        class_name_list: List of class names (kept for compatibility, not used)
        stage_1_flag: Whether this is Stage 1 inference (kept for compatibility, not used)
        device: PyTorch device for resampling
        verbose: Whether to print detailed information (default: True)

    Returns:
        data_dict: Dictionary containing preprocessed patches and metadata
    """
    # Private copy: the per-axis adjustment below writes target_spacing[i].
    target_spacing = list(target_spacing)

    raw_d, raw_h, raw_w = raw_image.shape
    image = rearrange(raw_image, 'd h w -> h w d')
    spacing = raw_spacing.astype(np.float32)

    # Step 1: guard against degenerate (near-zero) spacing values
    for i in range(3):
        if spacing[i] <= 0.1:
            spacing[i] = 1.0

    # Step 2: adjust spacing based on image dimensions
    spacing = adjust_spacing(image, spacing)

    # Step 3: constants for the per-axis constraints below
    max_dims = [1000, 1000, 700]
    min_dims = crop_size
    # Geometric ladder 1.25, 1.25^2, ... <= 50 used to relax target spacing.
    thresholds = []
    current = 1.25
    while current <= 50:
        thresholds.append(current)
        current *= 1.25
    raw_target_spacing = target_spacing.copy()

    # Step 4: adjust spacing per axis based on constraints
    for i in range(3):
        # If spacing is less than 1.0 and image dimension is within max_dims, set to 1.0
        if spacing[i] < 1.0 and image.shape[i] <= max_dims[i]:
            spacing[i] = 1.0  # second stage model resolution

        # If physical dimension exceeds max_dims and spacing is greater than target, use target spacing
        if spacing[i] * image.shape[i] > max_dims[i] * target_spacing[i] and spacing[i] > target_spacing[i]:
            spacing[i] = target_spacing[i]
        # If physical dimension is less than min_dims threshold, adjust target_spacing
        elif spacing[i] * image.shape[i] < min_dims[i] * target_spacing[i]:
            alpha_spacing = 1
            for threshold in reversed(thresholds):
                if image.shape[i] <= (min_dims[i] / threshold):
                    alpha_spacing = threshold
                    break

            raw_target_spacing[i] = target_spacing[i]
            target_spacing[i] = max(spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing)
            if verbose:
                print("alpha_spacing: ", alpha_spacing)
                print("spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing: ", spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing)
                print("raw_target_spacing[i], target_spacing[i]: ", raw_target_spacing[i], target_spacing[i])
            # Never coarsen beyond the originally requested target spacing.
            target_spacing[i] = min(raw_target_spacing[i], target_spacing[i])
            if verbose:
                print("image.shape[i], min_dims[i], target_spacing[i], spacing[i]: ", image.shape[i], min_dims[i], target_spacing[i], spacing[i])

    # Set default num_iterations (no special class handling)
    num_iterations = 1

    image = image[np.newaxis, ...].astype(np.float32)
    if verbose:
        print("image.shape: ", image.shape)
        print("spacing: ", spacing)
        print("target_spacing: ", target_spacing)
    image = respace_image(image, spacing, target_spacing, torch.device('cpu'))
    if verbose:
        print("respace image.shape: ", image.shape)
    image = torch.tensor(image)
    image, padding_info = pad_if_necessary(image, crop_size=crop_size)
    _, h, w, d = image.shape

    patches, y1y2_x1x2_z1z2_ls = split_3d(image, crop_size=crop_size)

    data_dict = {
        'spacing': spacing,
        'original_shape': (raw_h, raw_w, raw_d),
        'current_shape': (h, w, d),
        'patches': patches,
        'y1y2_x1x2_z1z2_ls': y1y2_x1x2_z1z2_ls,
        'padding_info': padding_info,
        'raw_image': raw_image,
        'num_iterations': num_iterations
    }

    if scaled_roi_lowres_pred_array is not None:
        # Put the low-res prompt through the exact same geometry pipeline
        # (resample -> pad -> split) so its patches align with the image's.
        lowres_pred = rearrange(scaled_roi_lowres_pred_array, 'd h w -> h w d')
        lowres_pred = lowres_pred[np.newaxis, ...].astype(np.float32)
        lowres_pred = respace_mask(lowres_pred, spacing, target_spacing, torch.device('cpu'))
        lowres_pred = torch.tensor(lowres_pred)
        lowres_pred, padding_info = pad_if_necessary(lowres_pred, crop_size=crop_size)
        lowres_pred_patches, _ = split_3d(lowres_pred, crop_size=crop_size)
        data_dict['lowres_pred_patches'] = lowres_pred_patches
        data_dict['padding_info'] = padding_info

    return data_dict
|
| 517 |
+
def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_factor: float = 10, dtype=np.float16):
    """
    Build a Gaussian importance map for weighted patch blending.

    A unit impulse at the tile centre is blurred with a Gaussian whose sigma
    is proportional to the tile size, then rescaled so the peak equals
    ``value_scaling_factor``. Used to down-weight patch borders when averaging
    overlapping predictions.

    Args:
        tile_size: Spatial extent of the tile (the crop size).
        sigma_scale: Gaussian sigma as a fraction of each tile dimension.
        value_scaling_factor: Value of the map at the tile centre.
        dtype: Output dtype.

    Returns:
        Gaussian importance map, strictly positive everywhere.
    """
    impulse = np.zeros(tile_size)
    center = tuple(dim // 2 for dim in tile_size)
    sigmas = [dim * sigma_scale for dim in tile_size]
    impulse[center] = 1

    importance = gaussian_filter(impulse, sigmas, 0, mode='constant', cval=0)
    importance = importance / np.max(importance) * value_scaling_factor
    importance = importance.astype(dtype)
    # The low-precision cast can underflow far-from-centre voxels to zero;
    # substitute the smallest nonzero weight so later divisions stay safe.
    importance[importance == 0] = np.min(importance[importance != 0])
    return importance
|
| 545 |
+
def sc_mask_to_mc_mask(sc_mask, label_values_ls):
    """
    Expand a single-channel label mask into a one-hot multi-channel mask.

    Args:
        sc_mask: Label mask of shape (1, 1, h, w, d) or (h, w, d).
        label_values_ls: Label values; one output channel is created per value.

    Returns:
        Float32 mask of shape (1, len(label_values_ls), h, w, d) where channel
        i is 1 exactly where the input equals label_values_ls[i].
    """
    volume = sc_mask.squeeze(0).squeeze(0)
    assert volume.ndim == 3
    h, w, d = volume.shape
    n_channels = len(label_values_ls)

    one_hot = torch.zeros((n_channels, h, w, d), dtype=torch.bool, device=volume.device)
    for channel, value in enumerate(label_values_ls):
        one_hot[channel] = (volume == value)

    return one_hot.to(torch.float32).unsqueeze(0)
|
| 568 |
+
class MedicalSegmentationPipeline:
    """
    Pipeline for medical image segmentation.

    This class handles model loading, data preprocessing, and inference execution
    for the Medal-S segmentation pipeline.
    """

    def __init__(self, config):
        """
        Initialize the segmentation pipeline.

        Args:
            config: Dictionary containing pipeline configuration parameters
                (must include at least 'device'; the remaining keys are read
                lazily by _load_model and run_inference).
        """
        self.config = config
        self.device = torch.device(config['device'])

    def _load_model(self):
        """
        Load vision model and text encoder from checkpoints.

        The vision checkpoint path is derived from crop size, model spacing and
        training step; both checkpoints may carry DataParallel 'module.'
        prefixes, which are stripped before loading.

        Returns:
            model: Loaded vision model (Maskformer), in eval mode
            text_encoder: Loaded text encoder (Knowledge_Encoder), in eval mode
        """
        crop_str = '_'.join(map(str, self.config['crop_size']))
        spacing_str = '_'.join(map(str, self.config['target_spacing_model']))

        vision_backbone_checkpoint = os.path.join(
            self.config['checkpoints_path'],
            f"nano_UNet_CVPR2025_crop_size_{crop_str}_spacing_{spacing_str}_step_{self.config['model_step']}.pth")

        model = Maskformer(
            self.config['vision_backbone'],
            self.config['input_channels'],
            self.config['crop_size'],
            self.config['patch_size'],
            False
        )
        model = model.to(self.device)
        checkpoint = torch.load(vision_backbone_checkpoint, map_location=self.device)
        # Strip DataParallel 'module.' prefixes and drop 'mid_mask_embed_proj'
        # weights that are not part of this architecture.
        new_state_dict = {
            k[7:] if k.startswith('module.') else k: v
            for k, v in checkpoint['model_state_dict'].items()
            if 'mid_mask_embed_proj' not in k
        }
        model.load_state_dict(new_state_dict)
        model.eval()

        text_encoder = Knowledge_Encoder(
            biolord_checkpoint=os.path.join(
                self.config['checkpoints_path'],
                'BioLORD-2023-C'
            )
        )
        text_encoder = text_encoder.to(self.device)
        checkpoint = torch.load(
            os.path.join(self.config['checkpoints_path'], 'text_encoder.pth'),
            map_location=self.device
        )
        new_state_dict = {
            k[7:] if k.startswith('module.') else k: v
            for k, v in checkpoint['model_state_dict'].items()
        }
        # strict=False: the encoder checkpoint may lack some heads used only
        # during training.
        text_encoder.load_state_dict(new_state_dict, strict=False)
        text_encoder.eval()

        return model, text_encoder

    def run_inference(self, raw_image, raw_spacing, verbose=True):
        """
        Run inference on the input image.

        This method performs the complete inference pipeline:
        1. Load models (vision backbone and text encoder)
        2. Preprocess image data (resampling, padding, patch splitting)
        3. Encode text prompts
        4. Process patches and aggregate predictions
        5. Post-process results (remove padding, resample to original shape)

        Args:
            raw_image: Input image array with shape (d, h, w)
            raw_spacing: Spacing array with shape (3,)
            verbose: Whether to print detailed information (default: True)

        Returns:
            pred_array: Segmentation array with shape (d, h, w), dtype int16
            max_prob_array: Maximum probability array (if return_max_prob=True), or None
        """
        model, text_encoder = self._load_model()
        pred_array = None
        crop_size = self.config['crop_size']
        disable_tta = self.config['disable_tta']
        instance_label = self.config['instance_label']
        modality = self.config['modality']
        text_prompts = self.config['texts']
        label_values = self.config['label_values']
        return_max_prob = self.config['return_max_prob']
        class_name_list = self.config['class_name_list']
        stage_1_flag = self.config['stage_1_flag']
        with torch.no_grad():
            # Gaussian is kept on CPU, as accumulation will now happen on CPU
            gaussian = torch.tensor(compute_gaussian(tuple(crop_size)), dtype=torch.float32).cpu()

            data_dict = read_npz_data(
                raw_image=raw_image,
                raw_spacing=raw_spacing,
                crop_size=crop_size,
                target_spacing=self.config['target_spacing'],
                scaled_roi_lowres_pred_array=self.config['scaled_roi_lowres_pred_array'],
                class_name_list=class_name_list,
                stage_1_flag=stage_1_flag,
                device=self.device,
                verbose=verbose
            )

            spacing = data_dict['spacing']
            original_shape = data_dict['original_shape']
            current_shape = data_dict['current_shape']
            batched_patches = data_dict['patches']
            batched_y1y2_x1x2_z1z2 = data_dict['y1y2_x1x2_z1z2_ls']
            padding_info = data_dict['padding_info']
            raw_image = data_dict['raw_image']
            num_iterations = data_dict['num_iterations']
            # Present only when a low-res prompt array was supplied.
            batched_lowres_pred_patches = data_dict.get('lowres_pred_patches')

            # NOTE(review): keys here are lowercase but run_segmentation's
            # default modality is 'CT' — confirm the caller lowercases it.
            modality_code = torch.tensor([{
                'ct': 0, 'mri': 1, 'us': 2, 'pet': 3, 'microscopy': 4
            }[modality]]).to(self.device)  # Keep modality_code on GPU if text_encoder needs it on GPU

            h, w, d = current_shape
            n_total_classes = len(text_prompts)

            # Get category batch size from config, default to 24
            category_batch_size = self.config.get('category_batch_size', 24)
            background_threshold = self.config.get('background_threshold', 0.5)

            # Initialize max_prob and max_class_label_value on CPU to save GPU memory
            max_prob = torch.zeros((h, w, d), dtype=torch.float32, device='cpu')
            max_class_label_value = torch.zeros((h, w, d), dtype=torch.int16, device='cpu')

            # Process categories in batches to avoid OOM
            category_range = range(0, n_total_classes, category_batch_size)
            pbar = tqdm(category_range, desc="Processing Categories")
            for i in pbar:
                current_category_texts = text_prompts[i:i + category_batch_size]
                current_label_values = label_values[i:i + category_batch_size]
                current_n = len(current_category_texts)
                end_idx = min(i + current_n - 1, n_total_classes - 1)

                # Update progress bar description with current category range
                pbar.set_description(f"Processing Categories {i}-{end_idx}")

                # Keep these large tensors on CPU for accumulation
                temp_prediction_batch_cpu = torch.zeros((current_n, h, w, d), dtype=torch.float32, device='cpu')
                temp_accumulation_batch_cpu = torch.zeros((current_n, h, w, d), dtype=torch.float32, device='cpu')

                # Encode text prompts for current batch
                with autocast(enabled=False):
                    queries = text_encoder(current_category_texts, modality_code, self.device)  # queries remain on GPU for model input

                # Process patches for current category batch
                for patches, lowres_pred_patches, y1y2_x1x2_z1z2_ls in tqdm(
                    zip(batched_patches, batched_lowres_pred_patches if batched_lowres_pred_patches is not None else [None]*len(batched_patches), batched_y1y2_x1x2_z1z2),
                    total=len(batched_patches),
                    desc="Processing",
                    ncols=100,
                    bar_format="{l_bar}{bar:20}{r_bar}",
                    colour="green",
                    leave=False
                ):
                    patches = patches.unsqueeze(0).to(device=self.device, dtype=torch.float32)  # patches on GPU for model input
                    y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls

                    simulated_lowres_sc_pred = None
                    simulated_lowres_mc_pred = None

                    if not self.config['w_lowres_pred_prompts']:
                        # No spatial prompts: feed all-zero placeholders.
                        simulated_lowres_sc_pred = torch.zeros((1, 1, *crop_size), device=self.device, dtype=torch.float32)
                        simulated_lowres_mc_pred = torch.zeros((1, current_n, *crop_size), device=self.device, dtype=torch.float32)
                        prediction_patch = model(
                            queries=queries,
                            image_input=patches,
                            simulated_lowres_sc_pred=simulated_lowres_sc_pred,
                            simulated_lowres_mc_pred=simulated_lowres_mc_pred,
                            train_mode=False
                        ) if self.config['disable_tta'] else internal_maybe_mirror_and_predict(
                            model=model,
                            queries=queries,
                            image_input=patches,
                            simulated_lowres_sc_pred=simulated_lowres_sc_pred,
                            simulated_lowres_mc_pred=simulated_lowres_mc_pred,
                            mirror_axes=(0, 1, 2)
                        )
                    else:
                        # Build binary (single-channel) and one-hot (multi-channel)
                        # prompts from the low-res prediction patch.
                        lowres_pred_patches = lowres_pred_patches.unsqueeze(0).to(device=self.device, dtype=torch.float32)
                        simulated_lowres_sc_pred = torch.where(lowres_pred_patches > 0, torch.ones_like(lowres_pred_patches), torch.zeros_like(lowres_pred_patches))
                        simulated_lowres_mc_pred = sc_mask_to_mc_mask(lowres_pred_patches, [int(val) for val in current_label_values])

                        possible_block_sizes = [8]
                        # NOTE(review): both branches assign n_repeats = 1; the
                        # split looks like a leftover tuning knob.
                        if instance_label == 1:
                            n_repeats = 1
                        else:
                            n_repeats = 1
                        prediction_patch = compute_patch_prediction(queries, patches, simulated_lowres_sc_pred, simulated_lowres_mc_pred, model, possible_block_sizes, n_repeats, disable_tta)

                    if instance_label == 1:  # Instance segmentation mode
                        # Iteratively refine: re-derive prompts from the current
                        # prediction and run complementary masking again.
                        for _ in range(num_iterations):
                            prediction_patch_prob = torch.sigmoid(prediction_patch).detach()
                            simulated_lowres_mc_pred = torch.where(prediction_patch_prob > 0.5, 1.0, 0.0)
                            simulated_lowres_sc_pred = (simulated_lowres_mc_pred.sum(dim=1, keepdim=True) > 0).float()
                            possible_block_sizes = [4]
                            n_repeats = 1
                            prediction_patch = compute_patch_prediction(queries, patches, simulated_lowres_sc_pred, simulated_lowres_mc_pred, model, possible_block_sizes, n_repeats, disable_tta)

                    prediction_patch_prob_gpu = torch.sigmoid(prediction_patch).detach()
                    current_gaussian_slice = gaussian[:y2-y1, :x2-x1, :z2-z1]  # Already on CPU

                    # Perform accumulation on CPU. Move prediction_patch_prob_gpu to CPU here.
                    temp_prediction_batch_cpu[:, y1:y2, x1:x2, z1:z2] += (prediction_patch_prob_gpu[0, :, :y2-y1, :x2-x1, :z2-z1].cpu() * current_gaussian_slice)
                    temp_accumulation_batch_cpu[:, y1:y2, x1:x2, z1:z2] += current_gaussian_slice

                    # Explicitly delete GPU tensors to free up memory immediately
                    del prediction_patch, prediction_patch_prob_gpu, patches
                    if simulated_lowres_sc_pred is not None:
                        del simulated_lowres_sc_pred
                    if simulated_lowres_mc_pred is not None:
                        del simulated_lowres_mc_pred
                    torch.cuda.empty_cache()  # Clear any cached GPU memory after each patch processing
                    gc.collect()  # Python garbage collection

                # Normalize predictions by accumulation
                batch_accumulation_cpu = temp_accumulation_batch_cpu
                batch_accumulation_cpu[batch_accumulation_cpu == 0] = 1e-8
                batch_prediction_prob_cpu = temp_prediction_batch_cpu / batch_accumulation_cpu

                # Update max_prob and max_class_label_value on CPU
                for j in range(current_n):
                    class_prob_cpu = batch_prediction_prob_cpu[j, ...]  # Already on CPU
                    class_label_value_cpu_scalar = torch.tensor(int(current_label_values[j]), dtype=torch.int16, device='cpu')  # Already on CPU

                    # Argmax across categories, maintained incrementally.
                    update_mask_cpu = class_prob_cpu > max_prob
                    max_prob[update_mask_cpu] = class_prob_cpu[update_mask_cpu]
                    max_class_label_value[update_mask_cpu] = class_label_value_cpu_scalar

                # Clean up batch tensors
                del temp_prediction_batch_cpu, temp_accumulation_batch_cpu, batch_accumulation_cpu, batch_prediction_prob_cpu, queries
                # Previous patch-level deletions handle GPU memory

            # Final operations on CPU
            background_indices = max_prob < background_threshold
            max_class_label_value[background_indices] = 0
            results = max_class_label_value.numpy()  # Already on CPU, just convert to numpy

            results = remove_padding(results, padding_info)
            current_h, current_w, current_d = results.shape
            if results.shape != original_shape:
                results = resample_torch_simple(
                    results[np.newaxis, ...],
                    new_shape=original_shape,
                    is_seg=True,
                    num_threads=4,
                    device=torch.device('cpu'),
                    memefficient_seg_resampling=False).squeeze(0)

                if verbose:
                    print(f"Resized segmentation from {current_h, current_w, current_d} to {original_shape}")

            pred_array = rearrange(results, 'h w d -> d h w').astype(np.int16)

            if return_max_prob and instance_label == 0:
                # max_prob is already on CPU, just convert to numpy for post-processing
                max_prob_numpy = max_prob.numpy()
                max_prob_numpy = remove_padding(max_prob_numpy, padding_info)
                current_h, current_w, current_d = max_prob_numpy.shape
                if max_prob_numpy.shape != original_shape:
                    max_prob_numpy = resample_torch_simple(
                        max_prob_numpy[np.newaxis, ...],
                        new_shape=original_shape,
                        is_seg=False,
                        num_threads=4,
                        device=torch.device('cpu'),
                        memefficient_seg_resampling=False).squeeze(0)

                    if verbose:
                        print(f"Resized max probability from {current_h, current_w, current_d} to {original_shape}")
                max_prob = rearrange(max_prob_numpy, 'h w d -> d h w').astype(np.float32)

            if return_max_prob and instance_label == 0:
                return pred_array, max_prob
            else:
                return pred_array, None
|
| 862 |
+
|
| 863 |
+
def run_segmentation(
    raw_image,
    raw_spacing,
    crop_size=None,
    target_spacing=None,
    target_spacing_model=None,
    w_lowres_pred_prompts=False,
    scaled_roi_lowres_pred_array=None,
    disable_tta=True,
    model_step=100000,
    vision_backbone="UNET",
    input_channels=2,
    patch_size=None,
    modality='CT',
    instance_label=0,
    texts=None,
    label_values=None,
    return_max_prob=False,
    class_name_list=None,
    stage_1_flag=False,
    device="cuda:0",
    checkpoints_path="./checkpoints",
    category_batch_size=24,
    background_threshold=0.5,
    verbose=True,
):
    """
    Main segmentation function.

    This function orchestrates the entire segmentation pipeline including
    model loading, data preprocessing, patch-based inference, and result aggregation.

    Args:
        raw_image: Input image array with shape (d, h, w), dtype uint8, values in [0, 255]
        raw_spacing: Spacing array with shape (3,)
        crop_size: Crop size for patch processing [h, w, d]. Defaults to [192, 192, 96].
        target_spacing: Target spacing for resampling [h, w, d]. Defaults to [1.5, 1.5, 3.0].
        target_spacing_model: Target spacing for model (should match target_spacing).
            Defaults to [1.5, 1.5, 3.0].
        w_lowres_pred_prompts: Whether to use low-res predictions as spatial prompts.
            NOTE: this flag is overridden below — it is derived from whether
            ``scaled_roi_lowres_pred_array`` is provided (original behavior preserved).
        scaled_roi_lowres_pred_array: Low-res prediction array for spatial prompts
        disable_tta: Disable test-time augmentation
        model_step: Model checkpoint step number
        vision_backbone: Vision backbone architecture name
        input_channels: Number of input channels
        patch_size: Patch size for the model. Defaults to [32, 32, 32].
        modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy')
        instance_label: 0 for semantic segmentation, 1 for instance segmentation
        texts: List of text prompts (one per class). Defaults to [].
        label_values: List of label values (one per class). Defaults to [].
        return_max_prob: Whether to return maximum probability map
        class_name_list: List of class names for class-specific adjustments. Defaults to [].
        stage_1_flag: Whether this is Stage 1 inference
        device: Device string (e.g., 'cuda:0' or 'cpu')
        checkpoints_path: Path to model checkpoints directory
        category_batch_size: Number of categories to process in each batch (default: 24)
            Adjust based on GPU memory. Larger 3D images require smaller batch sizes.
            Accumulation operations are performed on CPU for more stable memory usage.
        background_threshold: Probability threshold for background (default: 0.5)
            Voxels with max probability below this threshold will be labeled as background.
        verbose: Whether to print detailed information (default: True)

    Returns:
        pred_array: Segmentation array with shape (d, h, w), dtype int16
        max_prob_array: Maximum probability array (if return_max_prob=True), or None
    """
    # Resolve None sentinels to fresh defaults so no mutable list is shared
    # between calls (avoids the mutable-default-argument pitfall).
    if crop_size is None:
        crop_size = [192, 192, 96]
    if target_spacing is None:
        target_spacing = [1.5, 1.5, 3.0]
    if target_spacing_model is None:
        target_spacing_model = [1.5, 1.5, 3.0]
    if patch_size is None:
        patch_size = [32, 32, 32]
    if texts is None:
        texts = []
    if label_values is None:
        label_values = []
    if class_name_list is None:
        class_name_list = []

    # Intentional override (same as original code): spatial prompts are used
    # if and only if a low-res prediction array was actually supplied.
    w_lowres_pred_prompts = scaled_roi_lowres_pred_array is not None
    config = {
        'device': device,
        'modality': modality,
        'instance_label': instance_label,
        'texts': texts,
        'label_values': label_values,
        'vision_backbone': vision_backbone,
        'crop_size': crop_size,
        'patch_size': patch_size,
        'target_spacing': target_spacing,
        'target_spacing_model': target_spacing_model,
        'model_step': model_step,
        'input_channels': input_channels,
        'w_lowres_pred_prompts': w_lowres_pred_prompts,
        'scaled_roi_lowres_pred_array': scaled_roi_lowres_pred_array,
        'disable_tta': disable_tta,
        'checkpoints_path': checkpoints_path,
        'return_max_prob': return_max_prob,
        'class_name_list': class_name_list,
        'stage_1_flag': stage_1_flag,
        'category_batch_size': category_batch_size,
        'background_threshold': background_threshold,
    }

    pipeline = MedicalSegmentationPipeline(config)
    return pipeline.run_inference(raw_image, raw_spacing, verbose=verbose)
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
# ============================================================================
|
| 958 |
+
# Main Inference Functions
|
| 959 |
+
# ============================================================================
|
| 960 |
+
# These functions provide the high-level interface for running inference
|
| 961 |
+
# on raw NIfTI images with proper preprocessing and post-processing.
|
| 962 |
+
# ============================================================================
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def normalize_image_ct(image_data, window_level=40, window_width=400, window_type='soft_tissue'):
    """
    Apply window/level intensity normalization to a CT volume.

    Values are clipped to [level - width/2, level + width/2] and then
    min-max rescaled to the [0, 255] uint8 range.

    Args:
        image_data: Input CT image array (Hounsfield units).
        window_level: Window center. If None, taken from the preset for ``window_type``.
        window_width: Window range. If None, taken from the preset for ``window_type``.
        window_type: Preset name ('soft_tissue', 'bone', 'lung') used only when
            window_level/window_width are None; unknown names fall back to 'soft_tissue'.

    Returns:
        uint8 array with values in [0, 255].
    """
    # Preset (level, width) pairs keyed by anatomical window name.
    presets = {
        'soft_tissue': (40, 400),
        'bone': (500, 1500),
        'lung': (-600, 1500),
    }

    # If either parameter is missing, both are replaced from the preset
    # (falling back to the soft-tissue preset for unknown window types).
    if window_level is None or window_width is None:
        window_level, window_width = presets.get(window_type, presets['soft_tissue'])

    half_width = window_width / 2
    clipped = np.clip(image_data, window_level - half_width, window_level + half_width)

    lo = np.min(clipped)
    hi = np.max(clipped)
    # Epsilon guards against division by zero on constant images.
    scaled = (clipped - lo) / (hi - lo + 1e-8) * 255.0
    return scaled.astype(np.uint8)
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def normalize_image_other(image_data, percentile_lower=None, percentile_upper=None, preserve_zero=None, normalization_settings=None):
    """
    Normalize non-CT images using percentile-based normalization.

    Percentile bounds are computed over strictly positive voxels; the image
    is clipped to those bounds and min-max rescaled to [0, 255], optionally
    forcing original zero voxels back to zero.

    Args:
        image_data: Input image array.
        percentile_lower: Lower clipping percentile; None -> config value or 0.5.
        percentile_upper: Upper clipping percentile; None -> config value or 99.5.
        preserve_zero: Keep original zeros at zero; None -> config value or True.
        normalization_settings: Optional dict of the form
            {'percentile_lower': 0.5, 'percentile_upper': 99.5, 'preserve_zero': True}
            consulted for any parameter left as None.

    Returns:
        uint8 array with values in [0, 255].
    """
    defaults = {'percentile_lower': 0.5, 'percentile_upper': 99.5, 'preserve_zero': True}
    # Settings dict (if any) takes precedence over the hard-coded defaults
    # for every parameter the caller left unspecified.
    source = normalization_settings if normalization_settings is not None else {}
    if percentile_lower is None:
        percentile_lower = source.get('percentile_lower', defaults['percentile_lower'])
    if percentile_upper is None:
        percentile_upper = source.get('percentile_upper', defaults['percentile_upper'])
    if preserve_zero is None:
        preserve_zero = source.get('preserve_zero', defaults['preserve_zero'])

    # Percentiles are computed from strictly positive voxels only so that a
    # large zero background does not skew the bounds.
    positives = image_data[image_data > 0]
    if positives.size > 0:
        lower_bound, upper_bound = np.percentile(positives, [percentile_lower, percentile_upper])
    else:
        # Degenerate case: no positive voxels — fall back to the full range.
        lower_bound = np.min(image_data)
        upper_bound = np.max(image_data)

    clipped = np.clip(image_data, lower_bound, upper_bound)
    lo = np.min(clipped)
    hi = np.max(clipped)
    # Epsilon guards against division by zero on constant images.
    normalized = (clipped - lo) / (hi - lo + 1e-8) * 255.0

    if preserve_zero:
        normalized[image_data == 0] = 0

    return normalized.astype(np.uint8)
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
def load_nifti_image(image_path):
    """
    Load a NIfTI image and extract its array, spacing, and geometry metadata.

    Args:
        image_path: Path to a NIfTI image file.

    Returns:
        image_data: Image array with shape (d, h, w).
        spacing_xyz: Spacing tuple (x, y, z) as reported by SimpleITK.
        metadata: Dict with keys 'origin', 'direction', and 'spacing_xyz',
            suitable for restoring geometry on an output image.
    """
    volume = sitk.ReadImage(image_path)
    # SimpleITK returns the array with axes reversed relative to its
    # physical (x, y, z) spacing, i.e. shape is (d, h, w).
    image_data = sitk.GetArrayFromImage(volume)
    spacing_xyz = volume.GetSpacing()

    # Geometry metadata is carried along so the segmentation can be written
    # back with the same origin/direction/spacing as the input.
    metadata = {
        'origin': volume.GetOrigin(),
        'direction': volume.GetDirection(),
        'spacing_xyz': spacing_xyz,
    }

    return image_data, spacing_xyz, metadata
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
def convert_spacing(spacing_xyz, image_shape):
    """
    Convert SimpleITK (x, y, z) spacing into the layout run_segmentation expects.

    Pipeline (mirrors inference_raw_nifti_2.py):
      1. Permute (x, y, z) -> (z, x, y), treated here as (d, h, w).
         NOTE(review): the permutation [2, 0, 1] maps h to x and w to y,
         while other comments in this file describe h=y, w=x — confirm
         which convention is intended before changing it.
      2. Replace implausibly small spacing values (< 0.1) with 1.0.
      3. Best-effort adjustment via adjust_spacing (failures are ignored).
      4. Permute (d, h, w) -> (h, w, d) for run_segmentation.

    Args:
        spacing_xyz: Spacing tuple from SimpleITK (x, y, z).
        image_shape: Image shape (d, h, w), used only as a shape reference.

    Returns:
        Spacing array in the (h, w, d) layout expected by run_segmentation.
    """
    spacing = np.array(spacing_xyz, dtype=np.float32)

    # Step 1: (x, y, z) -> (z, x, y), interpreted as (d, h, w).
    spacing_dhw = spacing[[2, 0, 1]]

    # Step 2: clamp suspicious near-zero spacing values to 1.0 mm.
    for axis in range(3):
        if spacing_dhw[axis] < 0.1:
            spacing_dhw[axis] = 1.0

    # Step 3: optional refinement; any failure leaves the spacing untouched.
    # A dummy zero array stands in for the image, since adjust_spacing only
    # needs the shape.
    try:
        spacing_dhw = adjust_spacing(
            np.zeros(image_shape),
            spacing_dhw
        ).astype(np.float32)
    except Exception:
        # Best-effort: keep the unadjusted spacing.
        pass

    # Step 4: (d, h, w) -> (h, w, d).
    return spacing_dhw[[1, 2, 0]]
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
+
def run_inference_single_window(
    image_data,
    spacing_xyz,
    metadata,
    modality='CT',
    texts=None,
    label_values=None,
    inference_mode='stage2_only',
    device="cuda:0",
    checkpoints_path="./checkpoints",
    window_settings=None,
    window_type='soft_tissue',
    normalization_settings=None,
    verbose=True
):
    """
    Run inference for a single window type.

    This is an internal function used by run_inference to handle single window type inference.

    Args:
        image_data: Raw image data array (d, h, w)
        spacing_xyz: Spacing tuple (x, y, z)
        metadata: Image metadata dictionary (accepted for interface symmetry;
            not referenced anywhere in this function's body)
        modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy')
        texts: List of text prompts (one per class)
        label_values: List of label values (one per class)
        inference_mode: Inference mode ('stage2_only' or 'stage1+stage2')
        device: Device to use ('cuda:0' or 'cpu')
        checkpoints_path: Path to model checkpoints
        window_settings: Dictionary containing window settings for different window types (CT only)
        window_type: Type of window to use ('soft_tissue', 'bone', 'lung')
        normalization_settings: Dictionary containing normalization settings for non-CT modalities
        verbose: Whether to print detailed information (default: True)

    Returns:
        pred_array: Segmentation array (d, h, w)

    Raises:
        ValueError: If texts and label_values differ in length, or if
            inference_mode is not 'stage2_only' or 'stage1+stage2'.
    """
    # Resolve None sentinels to fresh lists (mutable-default avoidance).
    if texts is None:
        texts = []
    if label_values is None:
        label_values = []

    # Each text prompt must pair 1:1 with an output label value.
    if len(texts) != len(label_values):
        raise ValueError("Number of text prompts must match number of label values")

    # --- Intensity normalization (modality-dependent) ---
    if verbose:
        print(f"Normalizing image for {window_type} window (modality: {modality})")
    if modality.upper() == 'CT':
        # Window level/width come from config when available; otherwise
        # normalize_image_ct falls back to its per-window-type defaults.
        window_level = None
        window_width = None
        if window_settings is not None and window_type in window_settings:
            window_level = window_settings[window_type].get('window_level')
            window_width = window_settings[window_type].get('window_width')
            if verbose:
                print(f"Using {window_type} window: level={window_level}, width={window_width}")

        img_array = normalize_image_ct(image_data, window_level=window_level,
                                       window_width=window_width, window_type=window_type)
    else:
        # Non-CT modalities use percentile-based normalization.
        if normalization_settings is not None:
            if verbose:
                print(f"Using normalization settings from config: {normalization_settings}")
            img_array = normalize_image_other(image_data, normalization_settings=normalization_settings)
        else:
            if verbose:
                print("Using default normalization settings")
            img_array = normalize_image_other(image_data)

    if verbose:
        print(f"Normalized image range: [{img_array.min()}, {img_array.max()}]")

    # Convert (x, y, z) spacing into the (h, w, d) layout run_segmentation expects.
    img_spacing = convert_spacing(spacing_xyz, img_array.shape)
    if verbose:
        print(f"Converted spacing: {img_spacing}")

    # --- Inference ---
    if inference_mode == 'stage1+stage2':
        if verbose:
            print(f"Running two-stage inference with {window_type} window...")
        # Stage 1: coarse, low-resolution segmentation over the full volume,
        # used only to localize a region of interest for Stage 2.
        if verbose:
            print("Stage 1: Low-resolution segmentation...")
        stage_1_pred, _ = run_segmentation(
            raw_image=img_array,
            raw_spacing=img_spacing,
            crop_size=[224, 224, 128],
            target_spacing=[1.5, 1.5, 3.0],
            target_spacing_model=[1.5, 1.5, 3.0],
            w_lowres_pred_prompts=False,
            scaled_roi_lowres_pred_array=None,
            disable_tta=True,
            model_step=358600,
            modality=modality.lower(),
            instance_label=0,
            texts=texts,
            label_values=label_values,
            return_max_prob=False,
            class_name_list=[],
            stage_1_flag=True,
            device=device,
            checkpoints_path=checkpoints_path,
            verbose=verbose
        )

        # If Stage 1 found nothing, there is no ROI to refine: return as-is.
        if stage_1_pred.sum() == 0:
            if verbose:
                print("Warning: Stage 1 found no predictions. Using Stage 1 result as final output.")
            final_pred = stage_1_pred
        else:
            if verbose:
                print("Stage 1 completed. Extracting ROI for Stage 2...")

            # Suppress tiny connected components before computing the ROI so
            # isolated noise voxels do not inflate the bounding box.
            min_size = 10
            lowres_pred_binary = (stage_1_pred > 0).astype(np.int16)
            lowres_pred_binary = remove_small_objects_binary(lowres_pred_binary, min_size=min_size).astype(np.int16)
            stage_1_pred_cleaned = stage_1_pred * lowres_pred_binary

            # Bounding box of the surviving foreground.
            non_zero_indices = np.argwhere(stage_1_pred_cleaned > 0)
            if len(non_zero_indices) == 0:
                # Cleaning removed everything — fall back to the cleaned map.
                if verbose:
                    print("Warning: No non-zero regions after cleaning. Using Stage 1 result.")
                final_pred = stage_1_pred_cleaned
            else:
                z_min, y_min, x_min = non_zero_indices.min(axis=0)
                z_max, y_max, x_max = non_zero_indices.max(axis=0)

                # Expand the box by 10% around its center to give Stage 2 context.
                m = 1.1  # Scaling factor for ROI expansion
                z_center = (z_min + z_max) / 2
                y_center = (y_min + y_max) / 2
                x_center = (x_min + x_max) / 2

                z_range = (z_max - z_min + 1) * m / 2
                y_range = (y_max - y_min + 1) * m / 2
                x_range = (x_max - x_min + 1) * m / 2

                # Enforce a minimum ROI extent so it can hold at least one
                # Stage 2 crop after resampling to the Stage 2 spacing.
                stage_2_crop_size = [192, 192, 192]
                stage_2_target_spacing = [1.0, 1.0, 1.0]

                img_spacing_for_roi = img_spacing.copy()

                # img_spacing is in (h, w, d) order: index 0 -> y, 1 -> x, 2 -> z.
                min_z_range = (stage_2_crop_size[2] / 2) * stage_2_target_spacing[2] / img_spacing_for_roi[2] if img_spacing_for_roi[2] > 0 else z_range
                min_y_range = (stage_2_crop_size[0] / 2) * stage_2_target_spacing[0] / img_spacing_for_roi[0] if img_spacing_for_roi[0] > 0 else y_range
                min_x_range = (stage_2_crop_size[1] / 2) * stage_2_target_spacing[1] / img_spacing_for_roi[1] if img_spacing_for_roi[1] > 0 else x_range

                z_range = max(min_z_range - 1, z_range)
                y_range = max(min_y_range - 1, y_range)
                x_range = max(min_x_range - 1, x_range)

                # Clip the expanded box to the volume bounds.
                z_min_new = max(0, int(z_center - z_range))
                z_max_new = min(stage_1_pred_cleaned.shape[0] - 1, int(z_center + z_range))
                y_min_new = max(0, int(y_center - y_range))
                y_max_new = min(stage_1_pred_cleaned.shape[1] - 1, int(y_center + y_range))
                x_min_new = max(0, int(x_center - x_range))
                x_max_new = min(stage_1_pred_cleaned.shape[2] - 1, int(x_center + x_range))

                if verbose:
                    print(f"ROI bounds: z=[{z_min_new}:{z_max_new}], y=[{y_min_new}:{y_max_new}], x=[{x_min_new}:{x_max_new}]")

                # Crop both the image and the cleaned Stage 1 prediction
                # (the latter is fed to Stage 2 as a spatial prompt).
                roi_array = img_array[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1]
                roi_lowres_pred = stage_1_pred_cleaned[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1]

                if verbose:
                    print(f"ROI image shape: {roi_array.shape}")
                    print(f"ROI prediction shape: {roi_lowres_pred.shape}")

                # Stage 2: high-resolution segmentation restricted to the ROI,
                # prompted with the Stage 1 prediction.
                if verbose:
                    print("Stage 2: High-resolution segmentation on ROI...")
                roi_pred, _ = run_segmentation(
                    raw_image=roi_array,
                    raw_spacing=img_spacing,
                    crop_size=[192, 192, 192],
                    target_spacing=[1.0, 1.0, 1.0],
                    target_spacing_model=[1.0, 1.0, 1.0],
                    w_lowres_pred_prompts=True,
                    scaled_roi_lowres_pred_array=roi_lowres_pred,
                    disable_tta=True,
                    model_step=341300,
                    modality=modality.lower(),
                    instance_label=0,
                    texts=texts,
                    label_values=label_values,
                    return_max_prob=False,
                    class_name_list=[],
                    stage_1_flag=False,
                    device=device,
                    checkpoints_path=checkpoints_path,
                    verbose=verbose
                )

                # Paste the ROI prediction back at its original location;
                # everything outside the ROI is background.
                if verbose:
                    print("Integrating Stage 2 results back into full volume...")
                final_pred = np.zeros_like(stage_1_pred_cleaned, dtype=np.int16)
                final_pred[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1] = roi_pred
                if verbose:
                    print("Stage1+Stage2 inference completed.")
    elif inference_mode == 'stage2_only':
        # Single-stage path: high-resolution segmentation over the full volume.
        if verbose:
            print(f"Running Stage 2 inference with {window_type} window...")
        final_pred, _ = run_segmentation(
            raw_image=img_array,
            raw_spacing=img_spacing,
            crop_size=[192, 192, 192],
            target_spacing=[1.0, 1.0, 1.0],
            target_spacing_model=[1.0, 1.0, 1.0],
            w_lowres_pred_prompts=False,
            scaled_roi_lowres_pred_array=None,
            disable_tta=True,
            model_step=341300,
            modality=modality.lower(),
            instance_label=0,
            texts=texts,
            label_values=label_values,
            return_max_prob=False,
            class_name_list=[],
            stage_1_flag=False,
            device=device,
            checkpoints_path=checkpoints_path,
            verbose=verbose
        )
    else:
        raise ValueError(f"Unknown inference mode: {inference_mode}. Must be 'stage2_only' or 'stage1+stage2'")

    return final_pred
|
| 1381 |
+
|
| 1382 |
+
|
| 1383 |
+
def run_inference(
    image_path,
    output_path,
    modality='CT',
    texts=None,
    label_values=None,
    inference_mode='stage2_only',
    device="cuda:0",
    checkpoints_path="./checkpoints",
    window_settings=None,
    window_type='soft_tissue',
    normalization_settings=None,
    window_type_mapping=None,
    verbose=True
):
    """
    Run Medal-S inference on a raw NIfTI image.

    Supports multi-window inference for CT images: if multiple window types are specified
    (e.g., soft_tissue, bone, lung), each window type will be processed separately with
    its corresponding window settings, and results will be merged.

    Args:
        image_path: Path to input NIfTI image
        output_path: Path to save output segmentation (will be modified with mode suffix)
        modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy')
        texts: List of text prompts (one per class)
        label_values: List of label values (one per class)
        inference_mode: Inference mode ('stage2_only' or 'stage1+stage2')
        device: Device to use ('cuda:0' or 'cpu')
        checkpoints_path: Path to model checkpoints
        window_settings: Dictionary containing window settings for different window types (CT only).
            Format: {'soft_tissue': {'window_level': 40, 'window_width': 400}, ...}
        window_type: Type of window to use ('soft_tissue', 'bone', 'lung'). Default: 'soft_tissue' (CT only)
            Ignored if window_type_mapping indicates multiple window types
        normalization_settings: Dictionary containing normalization settings for non-CT modalities.
            Format: {'percentile_lower': 0.5, 'percentile_upper': 99.5, 'preserve_zero': True}
        window_type_mapping: Dictionary mapping each text to its window type.
            Format: {'text1': 'soft_tissue', 'text2': 'bone', ...}
            If provided and contains multiple window types, will perform separate inference for each
        verbose: Whether to print detailed information (default: True)

    Returns:
        pred_array: Segmentation array (d, h, w)
        inference_time: Total inference time in seconds

    Raises:
        ValueError: If texts and label_values differ in length.
    """
    # Resolve None sentinels to fresh lists (mutable-default avoidance).
    if texts is None:
        texts = []
    if label_values is None:
        label_values = []

    if len(texts) != len(label_values):
        raise ValueError("Number of text prompts must match number of label values")

    # Encode the inference mode in the output filename so results from
    # different modes do not overwrite each other.
    if inference_mode == 'stage1+stage2':
        suffix = '_stage1+stage2'
    elif inference_mode == 'stage2_only':
        suffix = '_stage2_only'
    else:
        suffix = f'_{inference_mode}'

    # Insert the suffix before the extension, handling the two-part
    # '.nii.gz' extension as a special case.
    base_path, ext = os.path.splitext(output_path)
    if ext == '.gz':  # Handle .nii.gz
        base_path, nii_ext = os.path.splitext(base_path)
        output_path = f"{base_path}{suffix}{nii_ext}{ext}"
    else:
        output_path = f"{base_path}{suffix}{ext}"

    if verbose:
        print(f"Output will be saved to: {output_path}")

    # Time the whole pipeline (load + inference + merge), excluding the save.
    start_time = time.time()

    # Load image and geometry metadata (origin/direction/spacing).
    if verbose:
        print(f"Loading image: {image_path}")
    image_data, spacing_xyz, metadata = load_nifti_image(image_path)
    if verbose:
        print(f"Image shape: {image_data.shape}")
        print(f"Original spacing (x, y, z): {spacing_xyz}")

    # Determine inference strategy based on modality and window types.
    if modality.upper() == 'CT':
        # CT modality: check for multiple window types.
        if window_type_mapping is not None:
            window_types = list(set(window_type_mapping.values()))
            if len(window_types) > 1:
                # Multiple window types: run one full inference pass per
                # window type, each restricted to the classes mapped to it.
                if verbose:
                    print(f"\n{'='*60}")
                    print(f"CT with {len(window_types)} window types detected: {window_types}")
                    print("Performing separate inference for each window type...")
                    print(f"{'='*60}\n")

                all_predictions = []

                for wt in window_types:
                    if verbose:
                        print(f"\n{'='*60}")
                        print(f"Processing {wt} window type...")
                        print(f"{'='*60}\n")

                    # Restrict texts/labels to the classes assigned to this
                    # window type (texts absent from the mapping are skipped).
                    wt_texts = [text for text in texts if window_type_mapping.get(text) == wt]
                    wt_indices = [i for i, text in enumerate(texts) if window_type_mapping.get(text) == wt]
                    wt_label_values = [label_values[i] for i in wt_indices]

                    if len(wt_texts) == 0:
                        if verbose:
                            print(f"No classes for {wt} window type, skipping...")
                        continue

                    if verbose:
                        print(f"Classes for {wt} window: {len(wt_texts)}")
                        print(f"  Texts: {wt_texts}")
                        print(f"  Labels: {wt_label_values}")

                    # Run inference for this window type with its specific window settings.
                    wt_pred = run_inference_single_window(
                        image_data=image_data,
                        spacing_xyz=spacing_xyz,
                        metadata=metadata,
                        modality=modality,
                        texts=wt_texts,
                        label_values=wt_label_values,
                        inference_mode=inference_mode,
                        device=device,
                        checkpoints_path=checkpoints_path,
                        window_settings=window_settings,
                        window_type=wt,  # Use the specific window type
                        normalization_settings=normalization_settings,
                        verbose=verbose
                    )

                    all_predictions.append((wt_pred, wt_label_values))

                # Merge predictions: use maximum label value when overlapping.
                if verbose:
                    print(f"\n{'='*60}")
                    print("Merging predictions from all window types...")
                    print(f"{'='*60}\n")

                # NOTE(review): if every window type was skipped above,
                # all_predictions is empty and this raises IndexError —
                # confirm whether that configuration can occur.
                final_pred = np.zeros_like(all_predictions[0][0], dtype=np.int16)
                for wt_pred, wt_labels in all_predictions:
                    # For each label in this window type's prediction.
                    for label_val in wt_labels:
                        label_int = int(label_val)
                        mask = (wt_pred == label_int)
                        # Overlap policy: the numerically larger label wins.
                        final_pred[mask] = np.maximum(final_pred[mask], label_int)

                if verbose:
                    print("Merging completed.")
            else:
                # Single window type: use the specific window type.
                if len(window_types) == 1:
                    window_type = window_types[0]
                    if verbose:
                        print(f"CT with single window type: {window_type}")

                final_pred = run_inference_single_window(
                    image_data=image_data,
                    spacing_xyz=spacing_xyz,
                    metadata=metadata,
                    modality=modality,
                    texts=texts,
                    label_values=label_values,
                    inference_mode=inference_mode,
                    device=device,
                    checkpoints_path=checkpoints_path,
                    window_settings=window_settings,
                    window_type=window_type,  # Use the determined window type
                    normalization_settings=normalization_settings,
                    verbose=verbose
                )
        else:
            # No window_type_mapping: use the default window_type argument.
            if verbose:
                print(f"CT without window_type_mapping, using window type: {window_type}")
            final_pred = run_inference_single_window(
                image_data=image_data,
                spacing_xyz=spacing_xyz,
                metadata=metadata,
                modality=modality,
                texts=texts,
                label_values=label_values,
                inference_mode=inference_mode,
                device=device,
                checkpoints_path=checkpoints_path,
                window_settings=window_settings,
                window_type=window_type,
                normalization_settings=normalization_settings,
                verbose=verbose
            )
    else:
        # Non-CT modality: windowing does not apply; percentile-based
        # normalization_settings are used instead.
        if verbose:
            print(f"Non-CT modality ({modality}): using normalization_settings")
        final_pred = run_inference_single_window(
            image_data=image_data,
            spacing_xyz=spacing_xyz,
            metadata=metadata,
            modality=modality,
            texts=texts,
            label_values=label_values,
            inference_mode=inference_mode,
            device=device,
            checkpoints_path=checkpoints_path,
            window_settings=window_settings,  # Not used for non-CT
            window_type=window_type,  # Not used for non-CT
            normalization_settings=normalization_settings,  # Used for non-CT
            verbose=verbose
        )

    # End timing.
    end_time = time.time()
    inference_time = end_time - start_time

    if verbose:
        print(f"\n{'='*60}")
        print(f"Inference Mode: {inference_mode}")
        print(f"Total Inference Time: {inference_time:.2f} seconds ({inference_time/60:.2f} minutes)")
        print(f"{'='*60}\n")

    # Save result with the input image's geometry restored.
    if verbose:
        print(f"Saving segmentation to: {output_path}")
    seg_sitk = sitk.GetImageFromArray(final_pred.astype(np.int16))
    seg_sitk.SetSpacing(metadata['spacing_xyz'])
    seg_sitk.SetOrigin(metadata['origin'])
    seg_sitk.SetDirection(metadata['direction'])
    sitk.WriteImage(seg_sitk, output_path)
    if verbose:
        print(f"Successfully saved segmentation to: {output_path}")

    return final_pred, inference_time
|
| 1622 |
+
|
| 1623 |
+
|
| 1624 |
+
def load_config_from_json(config_path):
    """
    Load an inference configuration from a JSON file.

    Supports two formats:
    1. Legacy format: a single 'texts' array
    2. New format: separate 'texts_soft_tissue', 'texts_bone', 'texts_lung'
       arrays, each mapped to its CT window type

    In both cases a 'window_type_mapping' (text -> window type) is attached to
    the returned config. If 'labels' is missing or empty, consecutive integer
    labels [1..n] are generated, where n is the number of texts.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        config: Dictionary of configuration parameters with processed
            'texts', 'labels' and 'window_type_mapping' entries.

    Raises:
        ValueError: If explicit labels are given but their count does not
            match the number of texts.
    """
    with open(config_path, 'r', encoding='utf-8') as fh:
        config = json.load(fh)

    grouped_keys = ('texts_soft_tissue', 'texts_bone', 'texts_lung')
    if any(key in config for key in grouped_keys):
        # New format: gather prompts per window type, preserving the
        # soft_tissue -> bone -> lung ordering of the combined list.
        combined = []
        mapping = {}
        for key, wtype in (('texts_soft_tissue', 'soft_tissue'),
                           ('texts_bone', 'bone'),
                           ('texts_lung', 'lung')):
            for prompt in config.get(key, []):
                combined.append(prompt)
                mapping[prompt] = wtype
        config['texts'] = combined
        config['window_type_mapping'] = mapping
    else:
        # Legacy format: default every prompt to the soft-tissue window
        # for backward compatibility.
        config['window_type_mapping'] = {
            prompt: 'soft_tissue' for prompt in config.get('texts', [])
        }

    texts = config.get('texts', [])
    labels = config.get('labels', None)

    if labels is None or len(labels) == 0:
        # Auto-generate consecutive labels starting from 1.
        labels = list(range(1, len(texts) + 1))
        print(f" Auto-generated consecutive labels: {labels}")
    else:
        # Accept both string and integer label entries.
        labels = [int(entry) for entry in labels]
        if len(labels) != len(texts):
            raise ValueError(
                f"Number of labels ({len(labels)}) must match number of texts ({len(texts)}). "
                f"Texts: {len(texts)}, Labels: {len(labels)}"
            )

    config['labels'] = labels
    return config
|
| 1712 |
+
|
| 1713 |
+
|
| 1714 |
+
def main():
    """
    Main entry point for the inference script.

    Parses command-line arguments (or a JSON config file) and runs inference
    with the specified configuration.
    """
    parser = argparse.ArgumentParser(
        description="Medal-S inference for raw NIfTI images",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Using JSON configuration file:
  python inference_medals.py --input image.nii.gz --output result.nii.gz \\
    --config config.json --mode stage2_only

  # Using command-line arguments:
  python inference_medals.py --input image.nii.gz --output result.nii.gz \\
    --modality CT --texts "Aorta in CT" --labels 1 --mode stage1+stage2
"""
    )
    # Data-driven argument registration; behavior matches adding each
    # argument individually in this order.
    arg_specs = [
        (("--input", "-i"),
         dict(type=str, required=True,
              help="Path to input NIfTI image")),
        (("--output", "-o"),
         dict(type=str, required=True,
              help="Path to save output segmentation (suffix will be added automatically based on inference mode)")),
        (("--config", "-c"),
         dict(type=str, default=None,
              help="Path to JSON configuration file (if provided, will override --texts, --labels, --modality)")),
        (("--modality", "-m"),
         dict(type=str, default="CT",
              choices=['CT', 'MRI', 'US', 'PET', 'microscopy'],
              help="Imaging modality (default: CT, ignored if --config is provided)")),
        (("--texts",),
         dict(type=str, nargs='+', default=None,
              help="Text prompts (one per class, ignored if --config is provided)")),
        (("--labels",),
         dict(type=str, nargs='+', default=None,
              help="Label values (one per class, must match texts, ignored if --config is provided)")),
        (("--mode",),
         dict(type=str, default="stage2_only",
              choices=['stage2_only', 'stage1+stage2'],
              help="Inference mode: 'stage2_only' (default) or 'stage1+stage2'")),
        (("--device",),
         dict(type=str, default="cuda:0",
              help="Device to use (default: cuda:0)")),
        (("--checkpoints",),
         dict(type=str, default="./checkpoints",
              help="Path to model checkpoints (default: ./checkpoints)")),
        (("--verbose", "-v"),
         dict(action='store_true', default=False,
              help="Print detailed information during inference (default: False)")),
    ]
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)

    args = parser.parse_args()
    verbose = args.verbose

    # Defaults used when no JSON config is supplied.
    window_settings = None
    window_type = 'soft_tissue'
    normalization_settings = None
    window_type_mapping = None

    if args.config:
        if not os.path.exists(args.config):
            raise FileNotFoundError(f"Configuration file not found: {args.config}")
        cfg = load_config_from_json(args.config)
        texts = cfg.get('texts', [])
        labels = cfg.get('labels', [])
        modality = cfg.get('modality', 'CT')
        window_settings = cfg.get('window_settings')
        normalization_settings = cfg.get('normalization_settings')
        window_type_mapping = cfg.get('window_type_mapping')

        # For CT with a single window type, use it directly; mixed types fall
        # back to soft_tissue (multi-window inference handles the rest).
        if modality.upper() == 'CT' and window_type_mapping:
            unique_types = list(set(window_type_mapping.values()))
            window_type = unique_types[0] if len(unique_types) == 1 else 'soft_tissue'

        # run_segmentation expects string labels.
        label_values = [str(entry) for entry in labels]

        if verbose:
            print(f"Loaded configuration from: {args.config}")
            print(f" Modality: {modality}")
            print(f" Number of classes: {len(texts)}")
            print(f" Labels: {labels}")
            if modality.upper() == 'CT' and window_settings:
                print(f" Window settings available for: {list(window_settings.keys())}")
                if window_type_mapping:
                    detected_types = list(set(window_type_mapping.values()))
                    if len(detected_types) > 1:
                        print(f" Multiple window types detected: {detected_types}")
                        print(f" Will perform separate inference for each window type")
                    else:
                        print(f" Using window type: {window_type}")
                else:
                    print(f" Using window type: {window_type}")
            elif normalization_settings:
                print(f" Normalization settings: {normalization_settings}")
    else:
        # No config file: texts and labels must come from the command line.
        if args.texts is None or args.labels is None:
            raise ValueError("Either --config or both --texts and --labels must be provided")
        texts = args.texts
        label_values = args.labels
        modality = args.modality

    # Create the output directory if it does not exist yet.
    target_dir = os.path.dirname(args.output)
    if target_dir and not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)

    run_inference(
        image_path=args.input,
        output_path=args.output,
        modality=modality,
        texts=texts,
        label_values=label_values,
        inference_mode=args.mode,
        device=args.device,
        checkpoints_path=args.checkpoints,
        window_settings=window_settings,
        window_type=window_type,
        normalization_settings=normalization_settings,
        window_type_mapping=window_type_mapping,
        verbose=verbose
    )


if __name__ == '__main__':
    main()
model/SwinUNETR.py
ADDED
|
@@ -0,0 +1,1116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Sequence, Tuple, Type, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.utils.checkpoint as checkpoint
|
| 8 |
+
from torch.nn import LayerNorm
|
| 9 |
+
|
| 10 |
+
from monai.networks.blocks import MLPBlock as Mlp
|
| 11 |
+
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
|
| 12 |
+
from monai.networks.layers import DropPath, trunc_normal_
|
| 13 |
+
from monai.utils import ensure_tuple_rep, optional_import
|
| 14 |
+
|
| 15 |
+
rearrange, _ = optional_import("einops", name="rearrange")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SwinUNETR_Enc(nn.Module):
    """
    Encoder half of Swin UNETR, based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    <https://arxiv.org/abs/2201.01266>"

    Wraps a SwinTransformer backbone plus one residual conv block per
    resolution level and returns the multi-scale skip features.
    """

    def __init__(
        self,
        img_size: Union[Sequence[int], int],
        in_channels: int,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 24,
        norm_name: Union[Tuple, str] = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        return_skips: bool = True,
    ) -> None:
        """
        Args:
            img_size: dimension of input image.
            in_channels: dimension of input channels.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            feature_size: dimension of network feature size.
            norm_name: feature normalization type and arguments.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            dropout_path_rate: drop path rate.
            normalize: normalize output intermediate features in each stage.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims.
            return_skips: if True, forward() returns all six skip features;
                otherwise only the bottleneck feature.
        """

        super().__init__()

        self.return_skips = return_skips

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(2, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)

        if not (spatial_dims == 2 or spatial_dims == 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        # With patch_size 2 and 5 downsampling stages, every spatial
        # dimension must be divisible by 2**5.
        for m, p in zip(img_size, patch_size):
            for i in range(5):
                if m % np.power(p, i + 1) != 0:
                    raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")

        if not (0 <= drop_rate <= 1):
            raise ValueError("dropout rate should be between 0 and 1.")

        if not (0 <= attn_drop_rate <= 1):
            raise ValueError("attention dropout rate should be between 0 and 1.")

        if not (0 <= dropout_path_rate <= 1):
            raise ValueError("drop path rate should be between 0 and 1.")

        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")

        self.normalize = normalize

        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=nn.LayerNorm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
        )

        def _conv_block(in_ch: int, out_ch: int) -> UnetrBasicBlock:
            # 2 conv layers with a residual connection (shared settings for
            # every encoder level).
            return UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=True,
            )

        # One residual conv block per resolution level; channel width doubles
        # at each deeper stage.
        self.encoder1 = _conv_block(in_channels, feature_size)
        self.encoder2 = _conv_block(feature_size, feature_size)
        self.encoder3 = _conv_block(2 * feature_size, 2 * feature_size)
        self.encoder4 = _conv_block(4 * feature_size, 4 * feature_size)
        self.encoder5 = _conv_block(8 * feature_size, 8 * feature_size)
        self.encoder6 = _conv_block(16 * feature_size, 16 * feature_size)

    def load_from(self, weights):
        """Copy pretrained SwinViT weights into this encoder's backbone.

        Args:
            weights: checkpoint dict with a "state_dict" entry whose keys are
                prefixed with "module." (DataParallel-style).
        """
        state_dict = weights["state_dict"]
        with torch.no_grad():
            self.swinViT.patch_embed.proj.weight.copy_(state_dict["module.patch_embed.proj.weight"])
            self.swinViT.patch_embed.proj.bias.copy_(state_dict["module.patch_embed.proj.bias"])
            # The four stages share the same key layout; iterate instead of
            # repeating one copy-pasted stanza per stage.
            for layer_name in ("layers1", "layers2", "layers3", "layers4"):
                layer = getattr(self.swinViT, layer_name)[0]
                for bname, block in layer.blocks.named_children():
                    block.load_from(weights, n_block=bname, layer=layer_name)
                prefix = f"module.{layer_name}.0.downsample"
                layer.downsample.reduction.weight.copy_(state_dict[f"{prefix}.reduction.weight"])
                layer.downsample.norm.weight.copy_(state_dict[f"{prefix}.norm.weight"])
                layer.downsample.norm.bias.copy_(state_dict[f"{prefix}.norm.bias"])

    def forward(self, x_in):
        """Encode ``x_in`` and return the per-resolution skip features."""
        hidden_states_out = self.swinViT(x_in, self.normalize)

        enc0 = self.encoder1(x_in)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        enc4 = self.encoder5(hidden_states_out[3])
        dec4 = self.encoder6(hidden_states_out[4])
        # Example shapes for a (6, 1, 64, 64, 64) input with feature_size=48:
        # enc0 (6,48,64,64,64), enc1 (6,48,32,32,32), enc2 (6,96,16,16,16),
        # enc3 (6,192,8,8,8), dec4 (6,768,2,2,2)

        if self.return_skips:
            return [enc0, enc1, enc2, enc3, enc4, dec4]
        else:
            return [dec4]
|
| 233 |
+
|
| 234 |
+
class SwinUNETR(nn.Module):
|
| 235 |
+
"""
|
| 236 |
+
Swin UNETR based on: "Hatamizadeh et al.,
|
| 237 |
+
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
|
| 238 |
+
<https://arxiv.org/abs/2201.01266>"
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
def __init__(
|
| 242 |
+
self,
|
| 243 |
+
img_size: Union[Sequence[int], int],
|
| 244 |
+
in_channels: int,
|
| 245 |
+
depths: Sequence[int] = (2, 2, 2, 2),
|
| 246 |
+
num_heads: Sequence[int] = (3, 6, 12, 24),
|
| 247 |
+
feature_size: int = 24,
|
| 248 |
+
norm_name: Union[Tuple, str] = "instance",
|
| 249 |
+
drop_rate: float = 0.0,
|
| 250 |
+
attn_drop_rate: float = 0.0,
|
| 251 |
+
dropout_path_rate: float = 0.0,
|
| 252 |
+
normalize: bool = True,
|
| 253 |
+
use_checkpoint: bool = False,
|
| 254 |
+
spatial_dims: int = 3,
|
| 255 |
+
encoding: Union[Tuple, str] = 'rand_embedding', ## rand_embedding or word_embedding
|
| 256 |
+
deep_supervision: bool = True,
|
| 257 |
+
return_skips: bool = True,
|
| 258 |
+
) -> None:
|
| 259 |
+
"""
|
| 260 |
+
Args:
|
| 261 |
+
img_size: dimension of input image.
|
| 262 |
+
in_channels: dimension of input channels.
|
| 263 |
+
out_channels: dimension of output channels.
|
| 264 |
+
feature_size: dimension of network feature size.
|
| 265 |
+
depths: number of layers in each stage.
|
| 266 |
+
num_heads: number of attention heads.
|
| 267 |
+
norm_name: feature normalization type and arguments.
|
| 268 |
+
drop_rate: dropout rate.
|
| 269 |
+
attn_drop_rate: attention dropout rate.
|
| 270 |
+
dropout_path_rate: drop path rate.
|
| 271 |
+
normalize: normalize output intermediate features in each stage.
|
| 272 |
+
use_checkpoint: use gradient checkpointing for reduced memory usage.
|
| 273 |
+
spatial_dims: number of spatial dims.
|
| 274 |
+
Examples::
|
| 275 |
+
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
|
| 276 |
+
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
|
| 277 |
+
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
|
| 278 |
+
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
|
| 279 |
+
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
|
| 280 |
+
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
|
| 281 |
+
"""
|
| 282 |
+
|
| 283 |
+
super().__init__()
|
| 284 |
+
|
| 285 |
+
self.deep_supervision = deep_supervision
|
| 286 |
+
self.return_skips = return_skips
|
| 287 |
+
|
| 288 |
+
self.encoding = encoding
|
| 289 |
+
|
| 290 |
+
img_size = ensure_tuple_rep(img_size, spatial_dims)
|
| 291 |
+
patch_size = ensure_tuple_rep(2, spatial_dims)
|
| 292 |
+
window_size = ensure_tuple_rep(7, spatial_dims)
|
| 293 |
+
|
| 294 |
+
if not (spatial_dims == 2 or spatial_dims == 3):
|
| 295 |
+
raise ValueError("spatial dimension should be 2 or 3.")
|
| 296 |
+
|
| 297 |
+
for m, p in zip(img_size, patch_size):
|
| 298 |
+
for i in range(5):
|
| 299 |
+
if m % np.power(p, i + 1) != 0:
|
| 300 |
+
raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
|
| 301 |
+
|
| 302 |
+
if not (0 <= drop_rate <= 1):
|
| 303 |
+
raise ValueError("dropout rate should be between 0 and 1.")
|
| 304 |
+
|
| 305 |
+
if not (0 <= attn_drop_rate <= 1):
|
| 306 |
+
raise ValueError("attention dropout rate should be between 0 and 1.")
|
| 307 |
+
|
| 308 |
+
if not (0 <= dropout_path_rate <= 1):
|
| 309 |
+
raise ValueError("drop path rate should be between 0 and 1.")
|
| 310 |
+
|
| 311 |
+
if feature_size % 12 != 0:
|
| 312 |
+
raise ValueError("feature_size should be divisible by 12.")
|
| 313 |
+
|
| 314 |
+
self.normalize = normalize
|
| 315 |
+
|
| 316 |
+
self.encoder = SwinUNETR_Enc(
|
| 317 |
+
img_size,
|
| 318 |
+
in_channels,
|
| 319 |
+
depths,
|
| 320 |
+
num_heads,
|
| 321 |
+
feature_size,
|
| 322 |
+
norm_name,
|
| 323 |
+
drop_rate,
|
| 324 |
+
attn_drop_rate,
|
| 325 |
+
dropout_path_rate,
|
| 326 |
+
normalize,
|
| 327 |
+
use_checkpoint,
|
| 328 |
+
spatial_dims,
|
| 329 |
+
return_skips=True
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
self.decoder5 = UnetrUpBlock( # a transpose conv layer and 2 conv layers
|
| 333 |
+
spatial_dims=spatial_dims,
|
| 334 |
+
in_channels=16 * feature_size,
|
| 335 |
+
out_channels=8 * feature_size,
|
| 336 |
+
kernel_size=3,
|
| 337 |
+
upsample_kernel_size=2,
|
| 338 |
+
norm_name=norm_name,
|
| 339 |
+
res_block=True,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
self.decoder4 = UnetrUpBlock(
|
| 343 |
+
spatial_dims=spatial_dims,
|
| 344 |
+
in_channels=feature_size * 8,
|
| 345 |
+
out_channels=feature_size * 4,
|
| 346 |
+
kernel_size=3,
|
| 347 |
+
upsample_kernel_size=2,
|
| 348 |
+
norm_name=norm_name,
|
| 349 |
+
res_block=True,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
self.decoder3 = UnetrUpBlock(
|
| 353 |
+
spatial_dims=spatial_dims,
|
| 354 |
+
in_channels=feature_size * 4,
|
| 355 |
+
out_channels=feature_size * 2,
|
| 356 |
+
kernel_size=3,
|
| 357 |
+
upsample_kernel_size=2,
|
| 358 |
+
norm_name=norm_name,
|
| 359 |
+
res_block=True,
|
| 360 |
+
)
|
| 361 |
+
self.decoder2 = UnetrUpBlock(
|
| 362 |
+
spatial_dims=spatial_dims,
|
| 363 |
+
in_channels=feature_size * 2,
|
| 364 |
+
out_channels=feature_size,
|
| 365 |
+
kernel_size=3,
|
| 366 |
+
upsample_kernel_size=2,
|
| 367 |
+
norm_name=norm_name,
|
| 368 |
+
res_block=True,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
self.decoder1 = UnetrUpBlock(
|
| 372 |
+
spatial_dims=spatial_dims,
|
| 373 |
+
in_channels=feature_size,
|
| 374 |
+
out_channels=feature_size,
|
| 375 |
+
kernel_size=3,
|
| 376 |
+
upsample_kernel_size=2,
|
| 377 |
+
norm_name=norm_name,
|
| 378 |
+
res_block=True,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def forward(self, x_in):
    """Run the Swin encoder, then decode with UNETR-style skip connections.

    Returns:
        (skips, out_ls): ``skips`` is the list of encoder features
        [enc0..enc4, bottleneck] when ``return_skips`` is set, else just
        the bottleneck; ``out_ls`` holds the decoder outputs — all scales
        (finest first) under deep supervision, otherwise only the
        full-resolution prediction.
    """
    features = self.encoder(x_in)
    enc_feats, bottleneck = features[:5], features[5]

    # Walk the decoders coarse -> fine, consuming encoder features in
    # reverse order as skip connections.
    decoders = (self.decoder5, self.decoder4, self.decoder3, self.decoder2, self.decoder1)
    decoded = []
    current = bottleneck
    for dec, skip in zip(decoders, reversed(enc_feats)):
        current = dec(current, skip)
        decoded.append(current)
    # decoded == [dec3, dec2, dec1, dec0, out] (coarse -> fine), e.g. for a
    # 64^3 patch: [.., 384,4^3], [.., 192,8^3], [.., 96,16^3], [.., 48,32^3], [.., 48,64^3]

    out = decoded[-1]
    # Deep supervision exposes every intermediate decoder scale after the output.
    out_ls = [out] + decoded[-2::-1] if self.deep_supervision else [out]
    skips = list(features) if self.return_skips else [bottleneck]

    return skips, out_ls
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def window_partition(x, window_size):
    """Partition a feature map into non-overlapping local windows.

    Based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer
    using Shifted Windows <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        x: channels-last tensor of shape (b, d, h, w, c) or (b, h, w, c);
           each spatial extent must be a multiple of the window size.
        window_size: local window size per spatial axis.

    Returns:
        Tensor of shape (num_windows * b, tokens_per_window, c).
    """
    shape = x.shape
    if len(shape) == 5:
        b, d, h, w, c = shape
        wd, wh, ww = window_size[0], window_size[1], window_size[2]
        # Split every spatial axis into (num_windows, window) pairs, then
        # bring the three window axes together before flattening.
        x = x.view(b, d // wd, wd, h // wh, wh, w // ww, ww, c)
        x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        windows = x.view(-1, wd * wh * ww, c)
    elif len(shape) == 4:
        b, h, w, c = shape
        wh, ww = window_size[0], window_size[1]
        x = x.view(b, h // wh, wh, w // ww, ww, c)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, wh * ww, c)
    return windows
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def window_reverse(windows, window_size, dims):
    """Reassemble windows produced by ``window_partition`` into a feature map.

    Based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer
    using Shifted Windows <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        windows: windows tensor of shape (num_windows * b, tokens_per_window, c).
        window_size: local window size per spatial axis.
        dims: target dimensions, (b, d, h, w) for 3D or (b, h, w) for 2D.

    Returns:
        Channels-last tensor of shape (b, d, h, w, c) or (b, h, w, c).
    """
    if len(dims) == 4:
        b, d, h, w = dims
        x = windows.view(
            b,
            d // window_size[0],
            h // window_size[1],
            w // window_size[2],
            window_size[0],
            window_size[1],
            window_size[2],
            -1,
        )
        # Interleave (num_windows, window) axis pairs back into d, h, w.
        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)

    elif len(dims) == 3:
        b, h, w = dims
        # BUG FIX: the width axis was previously divided by window_size[0]
        # instead of window_size[1], which makes the view fail (or silently
        # scramble tokens) for non-square 2D windows.
        x = windows.view(b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def get_window_size(x_size, window_size, shift_size=None):
    """Clamp window (and shift) sizes to the actual input extent per axis.

    Based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer
    using Shifted Windows <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        x_size: input spatial size.
        window_size: local window size.
        shift_size: window shifting size (optional).

    Returns:
        Clamped window size tuple, or a (window, shift) pair of tuples when
        ``shift_size`` is given. Axes no larger than the window use the full
        axis as the window and disable shifting there.
    """
    use_window = list(window_size)
    use_shift = list(shift_size) if shift_size is not None else None

    for axis, extent in enumerate(x_size):
        if extent <= use_window[axis]:
            # Window would exceed the input: shrink it and drop the shift.
            use_window[axis] = extent
            if use_shift is not None:
                use_shift[axis] = 0

    if use_shift is None:
        return tuple(use_window)
    return tuple(use_window), tuple(use_shift)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
class WindowAttention(nn.Module):
    """
    Window based multi-head self attention module with relative position bias based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size (length 2 or 3 selects 2D/3D bias tables).
            qkv_bias: add a learnable bias to query, key, value.
            attn_drop: attention dropout rate.
            proj_drop: dropout rate of output.
        """

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Standard 1/sqrt(d_head) scaling for scaled dot-product attention.
        self.scale = head_dim**-0.5
        # torch.meshgrid gained an `indexing` kwarg in newer torch versions;
        # __kwdefaults__ is non-None exactly when that kwarg exists.
        mesh_args = torch.meshgrid.__kwdefaults__

        if len(self.window_size) == 3:
            # One learnable bias per possible relative offset between two
            # positions in a (Wd, Wh, Ww) window, per attention head.
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(
                    (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
                    num_heads,
                )
            )
            coords_d = torch.arange(self.window_size[0])
            coords_h = torch.arange(self.window_size[1])
            coords_w = torch.arange(self.window_size[2])
            if mesh_args is not None:
                coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
            else:
                coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
            coords_flatten = torch.flatten(coords, 1)
            # Pairwise relative coordinates between every two window positions.
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            # Shift each axis to be non-negative, then mixed-radix encode the
            # three axes into a single flat index into the bias table.
            relative_coords[:, :, 0] += self.window_size[0] - 1
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 2] += self.window_size[2] - 1
            relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
            relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        elif len(self.window_size) == 2:
            # 2D variant of the same construction.
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
            )
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            if mesh_args is not None:
                coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
            else:
                coords = torch.stack(torch.meshgrid(coords_h, coords_w))
            coords_flatten = torch.flatten(coords, 1)
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += self.window_size[0] - 1
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

        relative_position_index = relative_coords.sum(-1)
        # Buffer (not a parameter): fixed lookup table of bias indices.
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """Windowed multi-head attention over x of shape (b*num_windows, n, c).

        ``mask`` (optional, shape (num_windows, n, n)) carries the shifted-window
        attention mask produced by ``compute_mask``.
        """
        b, n, c = x.shape
        # Project to q, k, v and split heads: each is (b, num_heads, n, c//num_heads).
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        # Look up the per-pair relative position bias and add it to the logits.
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.clone()[:n, :n].reshape(-1)
        ].reshape(n, n, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nw = mask.shape[0]
            # Broadcast the per-window mask over batch and heads (-100 entries
            # effectively zero out cross-region attention after softmax).
            attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        shift_size: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: str = "GELU",
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            shift_size: window shift size.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: stochastic depth rate.
            act_layer: activation layer.
            norm_layer: normalization layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # Stochastic depth: identity when drop_path == 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")

    def forward_part1(self, x, mask_matrix):
        """Attention half of the block: norm -> (shift) -> window attention
        -> (un-shift) -> crop padding. The residual add happens in forward()."""
        x_shape = x.size()
        x = self.norm1(x)
        if len(x_shape) == 5:
            b, d, h, w, c = x.shape
            # Clamp window/shift to the actual spatial size of this input.
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            pad_l = pad_t = pad_d0 = 0
            # Pad each spatial axis up to a multiple of the window size.
            pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
            pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
            _, dp, hp, wp, _ = x.shape
            dims = [b, dp, hp, wp]

        elif len(x_shape) == 4:
            b, h, w, c = x.shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            pad_l = pad_t = 0
            # BUG FIX: pad amounts were computed from the wrong axes
            # (pad_r from h, pad_b from w). F.pad's last-dim-first ordering
            # applies pad_r to the w axis and pad_b to the h axis, so the
            # remainders must come from w and h respectively; non-square 2D
            # inputs whose sides are not window multiples were padded wrong.
            pad_r = (window_size[1] - w % window_size[1]) % window_size[1]
            pad_b = (window_size[0] - h % window_size[0]) % window_size[0]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
            _, hp, wp, _ = x.shape
            dims = [b, hp, wp]

        if any(i > 0 for i in shift_size):
            # Shifted-window attention: cyclically roll the map so windows
            # straddle the previous partition, and apply the region mask.
            if len(x_shape) == 5:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        x_windows = window_partition(shifted_x, window_size)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        attn_windows = attn_windows.view(-1, *(window_size + (c,)))
        shifted_x = window_reverse(attn_windows, window_size, dims)
        if any(i > 0 for i in shift_size):
            # Undo the cyclic shift.
            if len(x_shape) == 5:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        else:
            x = shifted_x

        # Crop away the padding added above.
        if len(x_shape) == 5:
            if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
                x = x[:, :d, :h, :w, :].contiguous()
        elif len(x_shape) == 4:
            if pad_r > 0 or pad_b > 0:
                x = x[:, :h, :w, :].contiguous()

        return x

    def forward_part2(self, x):
        """MLP half of the block (norm -> MLP -> stochastic depth)."""
        return self.drop_path(self.mlp(self.norm2(x)))

    def load_from(self, weights, n_block, layer):
        """Copy pretrained weights for this block from a checkpoint state dict.

        Args:
            weights: checkpoint dict with a "state_dict" entry.
            n_block: block index within the layer.
            layer: layer name prefix in the checkpoint.
        """
        root = f"module.{layer}.0.blocks.{n_block}."
        block_names = [
            "norm1.weight",
            "norm1.bias",
            "attn.relative_position_bias_table",
            "attn.relative_position_index",
            "attn.qkv.weight",
            "attn.qkv.bias",
            "attn.proj.weight",
            "attn.proj.bias",
            "norm2.weight",
            "norm2.bias",
            "mlp.fc1.weight",
            "mlp.fc1.bias",
            "mlp.fc2.weight",
            "mlp.fc2.bias",
        ]
        with torch.no_grad():
            self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
            self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
            self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
            self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
            self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
            self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
            self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
            self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
            self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
            # Checkpoint uses fc1/fc2 names; MONAI's Mlp uses linear1/linear2.
            self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
            self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
            self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
            self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])

    def forward(self, x, mask_matrix):
        """Pre-norm transformer block: x + Attn(x), then x + MLP(x)."""
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)
        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)
        return x
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
class PatchMerging(nn.Module):
    """
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Halves every spatial axis and doubles the channel count by concatenating
    the 2^spatial_dims sub-grids and projecting them with a linear layer.
    """

    def __init__(
        self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
    ) -> None:  # type: ignore
        """
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims.
        """

        super().__init__()
        self.dim = dim
        if spatial_dims == 3:
            # 8 octants concatenated -> 8*dim channels, projected to 2*dim.
            self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(8 * dim)
        elif spatial_dims == 2:
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge 2x2(x2) neighborhoods of a channels-last feature map.

        Args:
            x: tensor of shape (b, d, h, w, c) or (b, h, w, c).

        Returns:
            Tensor with each spatial axis halved and 2*dim channels.
        """
        x_shape = x.size()
        if len(x_shape) == 5:
            b, d, h, w, c = x_shape
            # Pad odd axes by one so the stride-2 slices cover everything.
            pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
            x0 = x[:, 0::2, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, 0::2, :]
            x3 = x[:, 0::2, 0::2, 1::2, :]
            x4 = x[:, 1::2, 0::2, 1::2, :]
            # BUG FIX: x5 and x6 previously duplicated x2 and x3, so the
            # (1,1,0) and (0,1,1) octants were silently dropped (the bug MONAI
            # later fixed as PatchMergingV2). All 8 octants are now distinct.
            # NOTE(review): checkpoints trained with the duplicated slicing
            # are shape-compatible but semantically mismatched here — verify
            # before loading old pretrained weights.
            x5 = x[:, 1::2, 1::2, 0::2, :]
            x6 = x[:, 0::2, 1::2, 1::2, :]
            x7 = x[:, 1::2, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)

        elif len(x_shape) == 4:
            b, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
            x0 = x[:, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, :]
            x3 = x[:, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3], -1)

        x = self.norm(x)
        x = self.reduction(x)
        return x
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
def compute_mask(dims, window_size, shift_size, device):
    """Build the attention mask for shifted-window self attention.

    Based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer
    using Shifted Windows <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        dims: padded spatial dimensions, (d, h, w) or (h, w).
        window_size: local window size.
        shift_size: shift size.
        device: device to build the mask on.

    Returns:
        Tensor of shape (num_windows, tokens, tokens) holding 0 for pairs in
        the same contiguous region and -100 for cross-region pairs.
    """
    # Label each position with the index of its shifted-window region; tokens
    # from different regions must not attend to each other after the roll.
    region = 0
    if len(dims) == 3:
        depth, height, width = dims
        img_mask = torch.zeros((1, depth, height, width, 1), device=device)
        axis_slices = [
            (slice(-ws), slice(-ws, -ss), slice(-ss, None))
            for ws, ss in zip(window_size, shift_size)
        ]
        for d_sl in axis_slices[0]:
            for h_sl in axis_slices[1]:
                for w_sl in axis_slices[2]:
                    img_mask[:, d_sl, h_sl, w_sl, :] = region
                    region += 1

    elif len(dims) == 2:
        height, width = dims
        img_mask = torch.zeros((1, height, width, 1), device=device)
        axis_slices = [
            (slice(-ws), slice(-ws, -ss), slice(-ss, None))
            for ws, ss in zip(window_size, shift_size)
        ]
        for h_sl in axis_slices[0]:
            for w_sl in axis_slices[1]:
                img_mask[:, h_sl, w_sl, :] = region
                region += 1

    # Compare region labels pairwise within each window; non-zero differences
    # mark forbidden pairs and get a large negative bias before softmax.
    mask_windows = window_partition(img_mask, window_size).squeeze(-1)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
class BasicLayer(nn.Module):
    """
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        downsample: isinstance = None,  # type: ignore  # NOTE(review): `isinstance` is the builtin, not a type — should be Optional[Type[nn.Module]] (e.g. PatchMerging)
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depth: number of SwinTransformerBlocks in this stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate (scalar or one rate per block).
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            downsample: downsample layer class applied at the end of the stage.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.window_size = window_size
        # Shifted blocks roll by half a window; even-indexed blocks don't shift.
        self.shift_size = tuple(i // 2 for i in window_size)
        self.no_shift = tuple(0 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=self.window_size,
                    shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                )
                for i in range(depth)
            ]
        )
        self.downsample = downsample
        if self.downsample is not None:
            # Instantiate the downsample class (e.g. PatchMerging) for the
            # dimensionality implied by the window size.
            self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))

    def forward(self, x):
        """Run all blocks of this stage, then the optional downsample.

        Accepts channels-first input, (b, c, d, h, w) or (b, c, h, w), and
        returns the same layout (with spatial dims halved if downsampling).
        """
        x_shape = x.size()
        if len(x_shape) == 5:
            b, c, d, h, w = x_shape
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            # Blocks operate channels-last.
            x = rearrange(x, "b c d h w -> b d h w c")
            # Padded extents (next window multiple) for the shared attention mask.
            dp = int(np.ceil(d / window_size[0])) * window_size[0]
            hp = int(np.ceil(h / window_size[1])) * window_size[1]
            wp = int(np.ceil(w / window_size[2])) * window_size[2]
            attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
            for blk in self.blocks:
                x = blk(x, attn_mask)
            x = x.view(b, d, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = rearrange(x, "b d h w c -> b c d h w")

        elif len(x_shape) == 4:
            b, c, h, w = x_shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            x = rearrange(x, "b c h w -> b h w c")
            hp = int(np.ceil(h / window_size[0])) * window_size[0]
            wp = int(np.ceil(w / window_size[1])) * window_size[1]
            attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
            for blk in self.blocks:
                x = blk(x, attn_mask)
            x = x.view(b, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = rearrange(x, "b h w c -> b c h w")
        return x
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
class SwinTransformer(nn.Module):
    """
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
        """

        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
            spatial_dims=spatial_dims,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic depth rates increase linearly across all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # One single-stage ModuleList per resolution so load_from-style
        # checkpoint key layouts ("layersN.0.…") keep working.
        self.layers1 = nn.ModuleList()
        self.layers2 = nn.ModuleList()
        self.layers3 = nn.ModuleList()
        self.layers4 = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                # Channel count doubles at each stage via PatchMerging.
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                downsample=PatchMerging,
                use_checkpoint=use_checkpoint,
            )
            if i_layer == 0:
                self.layers1.append(layer)
            elif i_layer == 1:
                self.layers2.append(layer)
            elif i_layer == 2:
                self.layers3.append(layer)
            elif i_layer == 3:
                self.layers4.append(layer)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

    def proj_out(self, x, normalize=False):
        """Optionally layer-normalize a channels-first feature map.

        Applies LayerNorm over the channel axis (channels-last internally),
        returning x unchanged when ``normalize`` is False.
        """
        if normalize:
            x_shape = x.size()
            if len(x_shape) == 5:
                n, ch, d, h, w = x_shape
                x = rearrange(x, "n c d h w -> n d h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n d h w c -> n c d h w")
            elif len(x_shape) == 4:
                n, ch, h, w = x_shape
                x = rearrange(x, "n c h w -> n h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n h w c -> n c h w")
        return x

    def forward(self, x, normalize=True):
        """Return the 5-scale feature pyramid [x0_out..x4_out].

        NOTE(review): indexes layers1[0]..layers4[0] directly, so this
        assumes exactly 4 stages (len(depths) == 4) — confirm with callers.
        """
        x0 = self.patch_embed(x)
        x0 = self.pos_drop(x0)
        x0_out = self.proj_out(x0, normalize)
        x1 = self.layers1[0](x0.contiguous())
        x1_out = self.proj_out(x1, normalize)
        x2 = self.layers2[0](x1.contiguous())
        x2_out = self.proj_out(x2, normalize)
        x3 = self.layers3[0](x2.contiguous())
        x3_out = self.proj_out(x3, normalize)
        x4 = self.layers4[0](x3.contiguous())
        x4_out = self.proj_out(x4, normalize)
        return [x0_out, x1_out, x2_out, x3_out, x4_out]
|
| 1087 |
+
|
| 1088 |
+
if __name__ == '__main__':
    # Smoke test: build a SwinUNETR on GPU, run a random volume through it,
    # and print parameter count plus skip/output shapes.
    import os

    def get_parameter_number(model):
        # Count all vs. trainable parameters (in raw element counts).
        total_num = sum(p.numel() for p in model.parameters())
        trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return {'Total': total_num, 'Trainable': trainable_num}

    model = SwinUNETR(
        img_size=[288, 288, 96],  # the real input should satisfy : d,h,w > 32
        in_channels=3,
        feature_size=48,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=False,
        deep_supervision=True,
        return_skips=True,
    ).cuda()

    # Only rank 0 logs in distributed runs.
    if is_master():
        print(f"** UNET ** {get_parameter_number(model)['Total']/1e6}M parameters")

    image = torch.rand((1, 3, 288, 288, 96)).cuda()
    skips, outs = model(image)

    for s in skips:
        print(s.shape)
    for out in outs:
        print(out.shape)
|
model/__init__.py
ADDED
|
File without changes
|
model/base_bert.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from transformers import BertModel, AutoTokenizer
|
| 5 |
+
|
| 6 |
+
class BaseBERT(nn.Module):
    """Text encoder: BERT [CLS] embedding plus a learned modality embedding.

    Tokenizes raw strings internally and returns one 768-d vector per input
    text, shifted by an embedding of the given modality index.
    """

    def __init__(self, basebert_checkpoint='bert-base-uncased'):
        """
        Args:
            basebert_checkpoint: HuggingFace checkpoint name/path for both
                tokenizer and BERT weights.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(basebert_checkpoint)
        self.model = BertModel.from_pretrained(basebert_checkpoint)
        # 4 modality slots, 768 = BERT-base hidden size.
        # NOTE(review): modality index -> meaning mapping is defined by
        # callers — confirm it stays within [0, 3].
        self.modality_embed = nn.Embedding(4, 768)

    def forward(self, text, modality):
        """Encode `text` (str or list of str) conditioned on `modality` indices.

        Returns:
            Tensor of shape (batch, 768): [CLS] feature + modality embedding.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            return_tensors='pt',
            max_length=64,
        ).to(device=torch.cuda.current_device())

        # [CLS] token embedding as the sentence-level feature.
        text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
        modality_feature = self.modality_embed(modality)
        # BUG FIX: use out-of-place addition. The previous `text_feature +=`
        # mutated a view of the BERT output in place, which can invalidate
        # tensors autograd saved for backward and raise at backward time.
        text_feature = text_feature + modality_feature

        return text_feature
|
model/build_model.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import time
|
| 4 |
+
import os
|
| 5 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from .maskformer import Maskformer
|
| 10 |
+
|
| 11 |
+
from train.dist import is_master
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_parameter_number(model):
    """Return the total and trainable parameter counts of *model*.

    Args:
        model: any ``nn.Module``.

    Returns:
        dict with keys ``'Total'`` and ``'Trainable'`` (element counts).
    """
    counts = {'Total': 0, 'Trainable': 0}
    for param in model.parameters():
        n = param.numel()
        counts['Total'] += n
        if param.requires_grad:
            counts['Trainable'] += n
    return counts
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def build_maskformer(args, device, gpu_id):
    """Build a Maskformer, wrap it for distributed training, and report its size.

    Args:
        args: namespace providing vision_backbone, input_channels, crop_size,
            patch_size and deep_supervision.
        device: device the model is moved to before DDP wrapping.
        gpu_id: local GPU rank passed as the single DDP device id.

    Returns:
        The DDP-wrapped model.
    """
    model = Maskformer(args.vision_backbone, args.input_channels, args.crop_size,
                       args.patch_size, args.deep_supervision)

    model = model.to(device)
    # BatchNorm statistics must be synchronized across ranks under DDP.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # find_unused_parameters=True: not every parameter contributes to every
    # forward pass (e.g. deep-supervision branches).
    model = DDP(model, device_ids=[gpu_id], find_unused_parameters=True)

    # Reuse the module-level get_parameter_number instead of redefining an
    # identical copy inside this function.
    if is_master():
        print(f"** MODEL ** {get_parameter_number(model)['Total']/1e6}M parameters")

    return model
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_checkpoint(checkpoint_file,
                    resume,
                    partial_load,
                    model,
                    device,
                    optimizer=None,
                    ):
    """Load model (and optionally optimizer) state from a checkpoint.

    Args:
        checkpoint_file: path to a torch checkpoint containing at least
            'model_state_dict', and — when resuming — 'optimizer_state_dict'
            and 'step'.
        resume: if True, also restore the optimizer state and step counter.
        partial_load: if True, load only parameters whose name and shape match
            the current model, reporting every mismatch.
        model: target module.
        device: map_location for torch.load.
        optimizer: optimizer to restore when resuming (may be None).

    Returns:
        (model, optimizer, start_step) where start_step is checkpoint step + 1
        on resume, else 1.
    """
    if is_master():
        print('** CHECKPOINT ** : Load checkpoint from %s' % (checkpoint_file))

    checkpoint = torch.load(checkpoint_file, map_location=device)
    ckpt_state = checkpoint['model_state_dict']

    # load part of the checkpoint
    if partial_load:
        model_dict = model.state_dict()
        # Diagnose every kind of mismatch between checkpoint and model.
        unexpected_state_dict = [k for k in ckpt_state.keys() if k not in model_dict.keys()]
        missing_state_dict = [k for k in model_dict.keys() if k not in ckpt_state.keys()]
        unmatched_state_dict = [k for k, v in ckpt_state.items()
                                if k in model_dict.keys() and v.shape != model_dict[k].shape]
        # Keep only parameters whose name AND shape match the current model.
        state_dict = {k: v for k, v in ckpt_state.items()
                      if k in model_dict.keys() and v.shape == model_dict[k].shape}
        model_dict.update(state_dict)
        model.load_state_dict(model_dict)
        if is_master():
            print('The following parameters are unexpected in SAT checkpoint:\n', unexpected_state_dict)
            print('The following parameters are missing in SAT checkpoint:\n', missing_state_dict)
            print('The following parameters have different shapes in SAT checkpoint:\n', unmatched_state_dict)
            print('The following parameters are loaded in SAT:\n', state_dict.keys())
    else:
        model.load_state_dict(ckpt_state)

    # if resume, load optimizer and step
    if resume:
        # Guard against optimizer=None: the old code would raise
        # AttributeError inside a bare except and silently skip the load.
        if optimizer is not None:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except (KeyError, ValueError, RuntimeError):
                # Narrow except: only the expected "state dict doesn't match /
                # is absent" failures are treated as best-effort.
                print('Optimizer state dict not matched, skip loading optimizer state dict')
        start_step = int(checkpoint['step']) + 1
        print('Resume from step %d' % (start_step))
    else:
        start_step = 1

    return model, optimizer, start_step
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def inherit_knowledge_encoder(knowledge_encoder_checkpoint,
                              model,
                              device
                              ):
    """Copy the pretrained UNet encoder and projection layer into *model*.

    Loads a "knowledge encoder" checkpoint and transplants two groups of
    weights into the current model by renaming their keys:
      * 'atlas_tower.encoder*'          -> 'backbone.encoder*'
      * 'atlas_tower.projection_layer*' -> 'projection_layer*'
    All other model parameters are left untouched.

    Args:
        knowledge_encoder_checkpoint: path to the checkpoint file; must
            contain a 'model_state_dict' entry.
        model: target module whose state dict is partially overwritten.
        device: map_location for torch.load.

    Returns:
        The same *model* with the inherited weights loaded.
    """
    # inherit unet encoder and multiscale feature projection layer from knowledge encoder
    checkpoint = torch.load(knowledge_encoder_checkpoint, map_location=device)

    model_dict = model.state_dict()
    # Encoder part: remap 'atlas_tower' prefix to this model's 'backbone'.
    visual_encoder_state_dict = {k.replace('atlas_tower', 'backbone'):v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower.encoder' in k}
    model_dict.update(visual_encoder_state_dict)
    # Projection-layer part: strip the 'atlas_tower.' prefix entirely.
    proj_state_dict = {k.replace('atlas_tower.', ''):v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower.projection_layer' in k}
    model_dict.update(proj_state_dict)
    model.load_state_dict(model_dict)

    if is_master():
        print('** CHECKPOINT ** : Inherit pretrained unet encoder from %s' % (knowledge_encoder_checkpoint))
        print('The following parameters are loaded in SAT:\n', list(visual_encoder_state_dict.keys())+list(proj_state_dict.keys()))

    return model
|
model/dynamic-network-architectures-main/.gitignore
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
env/
|
| 12 |
+
build/
|
| 13 |
+
develop-eggs/
|
| 14 |
+
dist/
|
| 15 |
+
downloads/
|
| 16 |
+
eggs/
|
| 17 |
+
.eggs/
|
| 18 |
+
lib/
|
| 19 |
+
lib64/
|
| 20 |
+
parts/
|
| 21 |
+
sdist/
|
| 22 |
+
var/
|
| 23 |
+
*.egg-info/
|
| 24 |
+
.installed.cfg
|
| 25 |
+
*.egg
|
| 26 |
+
|
| 27 |
+
# PyInstaller
|
| 28 |
+
# Usually these files are written by a python script from a template
|
| 29 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 30 |
+
*.manifest
|
| 31 |
+
*.spec
|
| 32 |
+
|
| 33 |
+
# Installer logs
|
| 34 |
+
pip-log.txt
|
| 35 |
+
pip-delete-this-directory.txt
|
| 36 |
+
|
| 37 |
+
# Unit test / coverage reports
|
| 38 |
+
htmlcov/
|
| 39 |
+
.tox/
|
| 40 |
+
.coverage
|
| 41 |
+
.coverage.*
|
| 42 |
+
.cache
|
| 43 |
+
nosetests.xml
|
| 44 |
+
coverage.xml
|
| 45 |
+
*,cover
|
| 46 |
+
.hypothesis/
|
| 47 |
+
|
| 48 |
+
# Translations
|
| 49 |
+
*.mo
|
| 50 |
+
*.pot
|
| 51 |
+
|
| 52 |
+
# Django stuff:
|
| 53 |
+
*.log
|
| 54 |
+
local_settings.py
|
| 55 |
+
|
| 56 |
+
# Flask stuff:
|
| 57 |
+
instance/
|
| 58 |
+
.webassets-cache
|
| 59 |
+
|
| 60 |
+
# Scrapy stuff:
|
| 61 |
+
.scrapy
|
| 62 |
+
|
| 63 |
+
# Sphinx documentation
|
| 64 |
+
docs/_build/
|
| 65 |
+
|
| 66 |
+
# PyBuilder
|
| 67 |
+
target/
|
| 68 |
+
|
| 69 |
+
# IPython Notebook
|
| 70 |
+
.ipynb_checkpoints
|
| 71 |
+
|
| 72 |
+
# pyenv
|
| 73 |
+
.python-version
|
| 74 |
+
|
| 75 |
+
# celery beat schedule file
|
| 76 |
+
celerybeat-schedule
|
| 77 |
+
|
| 78 |
+
# dotenv
|
| 79 |
+
.env
|
| 80 |
+
|
| 81 |
+
# virtualenv
|
| 82 |
+
venv/
|
| 83 |
+
ENV/
|
| 84 |
+
|
| 85 |
+
# Spyder project settings
|
| 86 |
+
.spyderproject
|
| 87 |
+
|
| 88 |
+
# Rope project settings
|
| 89 |
+
.ropeproject
|
| 90 |
+
|
| 91 |
+
*.memmap
|
| 92 |
+
*.zip
|
| 93 |
+
*.npz
|
| 94 |
+
*.npy
|
| 95 |
+
*.jpg
|
| 96 |
+
*.jpeg
|
| 97 |
+
.idea
|
| 98 |
+
*.txt
|
| 99 |
+
.idea/*
|
| 100 |
+
*.nii.gz
|
| 101 |
+
*.nii
|
| 102 |
+
*.tif
|
| 103 |
+
*.bmp
|
| 104 |
+
*.pkl
|
| 105 |
+
*.xml
|
| 106 |
+
*.pkl
|
| 107 |
+
*.pdf
|
| 108 |
+
*.jpg
|
| 109 |
+
*.jpeg
|
| 110 |
+
|
| 111 |
+
*.model
|
| 112 |
+
|
| 113 |
+
cifar_lightning/mlruns*
|
model/dynamic-network-architectures-main/LICENCE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [2022] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
model/dynamic-network-architectures-main/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dynamic Network Architectures
|
| 2 |
+
|
| 3 |
+
This repository contains several ResNet, U-Net and VGG architectures in pytorch that can be dynamically adapted to a varying number of image dimensions (1D, 2D or 3D) and the number of input channels.
|
| 4 |
+
|
| 5 |
+
## Available models
|
| 6 |
+
### ResNet
|
| 7 |
+
We implement the standard [ResNetD](https://arxiv.org/pdf/1812.01187.pdf) 18, 34, 50 and 152. For ResNets 50 and 152 also bottleneck implementations are available. Moreover, adapted versions that are better suited for smaller image sizes such as CIFAR can be used.
|
| 8 |
+
|
| 9 |
+
All models additionally include regularization techniques like [Stochastic Depth](https://arxiv.org/pdf/1603.09382.pdf), [Squeeze & Excitation](https://arxiv.org/pdf/1709.01507.pdf) and [Final Layer Dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf).
|
| 10 |
+
|
| 11 |
+
### VGG
|
| 12 |
+
In contrast to the original [VGG](https://arxiv.org/pdf/1409.1556.pdf) implementation we exclude the final fully-connected layers in the end and replace it by additional convolutional layers and only one fully-connected layer in the end. Adapted versions that are better suited for smaller image sizes such as CIFAR can be used.
|
| 13 |
+
|
| 14 |
+
### U-Net
|
| 15 |
+
For the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) a plain convolutional encoder as well as a residual encoder are available.
|
| 16 |
+
|
| 17 |
+
# Acknowledgements
|
| 18 |
+
|
| 19 |
+
<p align="left">
|
| 20 |
+
<img src="imgs/Logos/HI_Logo.png" width="150">
|
| 21 |
+
<img src="imgs/Logos/DKFZ_Logo.png" width="500">
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
This Repository is developed and maintained by the Applied Computer Vision Lab (ACVL)
|
| 25 |
+
of [Helmholtz Imaging](https://www.helmholtz-imaging.de/).
|
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: dynamic_network_architectures
|
| 3 |
+
Version: 0.2
|
| 4 |
+
Summary: none
|
| 5 |
+
Author: Fabian Isensee
|
| 6 |
+
Author-email: f.isensee@dkfz.de
|
| 7 |
+
License: private
|
| 8 |
+
License-File: LICENCE
|
| 9 |
+
Requires-Dist: torch>=1.6.0a
|
| 10 |
+
Requires-Dist: numpy
|
| 11 |
+
Dynamic: author
|
| 12 |
+
Dynamic: author-email
|
| 13 |
+
Dynamic: license
|
| 14 |
+
Dynamic: license-file
|
| 15 |
+
Dynamic: requires-dist
|
| 16 |
+
Dynamic: summary
|
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENCE
|
| 2 |
+
README.md
|
| 3 |
+
setup.py
|
| 4 |
+
dynamic_network_architectures/__init__.py
|
| 5 |
+
dynamic_network_architectures.egg-info/PKG-INFO
|
| 6 |
+
dynamic_network_architectures.egg-info/SOURCES.txt
|
| 7 |
+
dynamic_network_architectures.egg-info/dependency_links.txt
|
| 8 |
+
dynamic_network_architectures.egg-info/not-zip-safe
|
| 9 |
+
dynamic_network_architectures.egg-info/requires.txt
|
| 10 |
+
dynamic_network_architectures.egg-info/top_level.txt
|
| 11 |
+
dynamic_network_architectures/architectures/__init__.py
|
| 12 |
+
dynamic_network_architectures/architectures/resnet.py
|
| 13 |
+
dynamic_network_architectures/architectures/unet.py
|
| 14 |
+
dynamic_network_architectures/architectures/vgg.py
|
| 15 |
+
dynamic_network_architectures/building_blocks/__init__.py
|
| 16 |
+
dynamic_network_architectures/building_blocks/helper.py
|
| 17 |
+
dynamic_network_architectures/building_blocks/plain_conv_encoder.py
|
| 18 |
+
dynamic_network_architectures/building_blocks/regularization.py
|
| 19 |
+
dynamic_network_architectures/building_blocks/residual.py
|
| 20 |
+
dynamic_network_architectures/building_blocks/residual_encoders.py
|
| 21 |
+
dynamic_network_architectures/building_blocks/simple_conv_blocks.py
|
| 22 |
+
dynamic_network_architectures/building_blocks/unet_decoder.py
|
| 23 |
+
dynamic_network_architectures/initialization/__init__.py
|
| 24 |
+
dynamic_network_architectures/initialization/weight_init.py
|
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/not-zip-safe
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.6.0a
|
| 2 |
+
numpy
|
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
dynamic_network_architectures
|
model/dynamic-network-architectures-main/dynamic_network_architectures/__init__.py
ADDED
|
File without changes
|
model/dynamic-network-architectures-main/dynamic_network_architectures/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (256 Bytes). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__init__.py
ADDED
|
File without changes
|
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (270 Bytes). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/unet.cpython-310.pyc
ADDED
|
Binary file (7.52 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/resnet.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder, BottleneckD, BasicBlockD
|
| 3 |
+
from dynamic_network_architectures.building_blocks.helper import get_matching_pool_op, get_default_network_config
|
| 4 |
+
from dynamic_network_architectures.building_blocks.simple_conv_blocks import ConvDropoutNormReLU
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
def _resnet_cfg(features_per_stage, n_blocks_per_stage, block, bottleneck_channels,
                disable_default_stem, stem_channels):
    """Assemble one ResNet configuration dict.

    All variants share four resolution stages with strides (1, 2, 2, 2);
    only the per-stage widths, block counts and block type differ.
    """
    return {'features_per_stage': features_per_stage,
            'n_blocks_per_stage': n_blocks_per_stage,
            'strides': (1, 2, 2, 2),
            'block': block,
            'bottleneck_channels': bottleneck_channels,
            'disable_default_stem': disable_default_stem,
            'stem_channels': stem_channels}


# Channel layouts shared across variants.
_BASIC_FEAT = (64, 128, 256, 512)      # BasicBlockD variants
_BN_FEAT = (256, 512, 1024, 2048)      # BottleneckD variants (4x expanded outputs)
_BN_CHANNELS = (64, 128, 256, 512)     # per-stage bottleneck widths

# 'disable_default_stem' turns off the encoder's built-in stem; ResNetD then builds
# the ImageNet-style ResNet-D stem itself. The '*_cifar' variants keep the encoder's
# default stem (no aggressive early downsampling for small inputs).
_ResNet_CONFIGS = {
    '18': _resnet_cfg(_BASIC_FEAT, (2, 2, 2, 2), BasicBlockD, None, True, None),
    '34': _resnet_cfg(_BASIC_FEAT, (3, 4, 6, 3), BasicBlockD, None, True, None),
    '50': _resnet_cfg(_BASIC_FEAT, (4, 6, 10, 5), BasicBlockD, None, True, None),
    '152': _resnet_cfg(_BASIC_FEAT, (4, 13, 55, 4), BasicBlockD, None, True, None),
    '50_bn': _resnet_cfg(_BN_FEAT, (3, 4, 6, 3), BottleneckD, _BN_CHANNELS, True, 64),
    '152_bn': _resnet_cfg(_BN_FEAT, (3, 8, 36, 3), BottleneckD, _BN_CHANNELS, True, 64),
    '18_cifar': _resnet_cfg(_BASIC_FEAT, (2, 2, 2, 2), BasicBlockD, None, False, None),
    '34_cifar': _resnet_cfg(_BASIC_FEAT, (3, 4, 6, 3), BasicBlockD, None, False, None),
    '50_cifar': _resnet_cfg(_BASIC_FEAT, (4, 6, 10, 5), BasicBlockD, None, False, None),
    '152_cifar': _resnet_cfg(_BASIC_FEAT, (4, 13, 55, 4), BasicBlockD, None, False, None),
    '50_cifar_bn': _resnet_cfg(_BN_FEAT, (3, 4, 6, 3), BottleneckD, _BN_CHANNELS, False, 64),
    '152_cifar_bn': _resnet_cfg(_BN_FEAT, (3, 8, 36, 3), BottleneckD, _BN_CHANNELS, False, 64),
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ResNetD(nn.Module):
    def __init__(self, n_classes: int, n_input_channel: int = 3, config='18', input_dimension=2,
                 final_layer_dropout=0.0, stochastic_depth_p=0.0, squeeze_excitation=False,
                 squeeze_excitation_rd_ratio=1./16):
        """
        Implements ResNetD (https://arxiv.org/pdf/1812.01187.pdf).

        Args:
            n_classes: Number of classes
            n_input_channel: Number of input channels (e.g. 3 for RGB)
            config: Key into _ResNet_CONFIGS selecting depth / block layout
            input_dimension: Number of dimensions of the data (1, 2 or 3)
            final_layer_dropout: Probability of dropout before the final classifier
            stochastic_depth_p: Stochastic Depth probability
            squeeze_excitation: Whether Squeeze and Excitation should be applied
            squeeze_excitation_rd_ratio: Squeeze and Excitation Reduction Ratio
        Returns:
            ResNet Model
        """
        super().__init__()
        self.input_channels = n_input_channel
        self.cfg = _ResNet_CONFIGS[config]
        self.ops = get_default_network_config(dimension=input_dimension)
        self.final_layer_dropout_p = final_layer_dropout

        # 'disable_default_stem' refers to the encoder's built-in stem. When it is
        # disabled we build the ResNet-D ImageNet stem (three 3x3 convs + pool) here.
        if self.cfg['disable_default_stem']:
            stem_features = self.cfg['stem_channels'] if self.cfg['stem_channels'] is not None else \
                self.cfg['features_per_stage'][0]
            self.stem = self._build_imagenet_stem_D(stem_features)
            encoder_input_features = stem_features
        else:
            encoder_input_features = n_input_channel
            self.stem = None

        self.encoder = ResidualEncoder(encoder_input_features, n_stages=len(self.cfg['features_per_stage']),
                                       features_per_stage=self.cfg['features_per_stage'], conv_op=self.ops['conv_op'],
                                       kernel_sizes=3, strides=self.cfg['strides'],
                                       n_blocks_per_stage=self.cfg['n_blocks_per_stage'], conv_bias=False,
                                       norm_op=self.ops['norm_op'], norm_op_kwargs=None, dropout_op=None,
                                       dropout_op_kwargs=None, nonlin=nn.ReLU,
                                       nonlin_kwargs={'inplace': True}, block=self.cfg['block'],
                                       bottleneck_channels=self.cfg['bottleneck_channels'], return_skips=False,
                                       disable_default_stem=self.cfg['disable_default_stem'],
                                       stem_channels=self.cfg['stem_channels'],
                                       stochastic_depth_p=stochastic_depth_p,
                                       squeeze_excitation=squeeze_excitation,
                                       squeeze_excitation_reduction_ratio=squeeze_excitation_rd_ratio)

        # Global average pooling to a single output position per channel, then linear classifier.
        self.gap = get_matching_pool_op(conv_op=self.ops['conv_op'], adaptive=True, pool_type='avg')(1)
        self.classifier = nn.Linear(self.cfg['features_per_stage'][-1], n_classes, True)
        self.final_layer_dropout = self.ops['dropout_op'](p=self.final_layer_dropout_p)

    def forward(self, x):
        """Run stem (if any), encoder, pooling, dropout and classifier.

        Returns logits of shape (batch, n_classes).
        """
        if self.stem is not None:
            x = self.stem(x)
        x = self.encoder(x)
        x = self.gap(x)
        x = self.final_layer_dropout(x)
        # BUGFIX: was `.squeeze()`, which also removes the batch dimension when
        # batch size is 1 — the classifier then received a 1-D tensor and returned
        # shape (n_classes,) instead of (1, n_classes). flatten(1) only collapses
        # the pooled spatial dimensions (all of size 1 after adaptive avg pool).
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _build_imagenet_stem_D(self, stem_features):
        """
        https://arxiv.org/pdf/1812.01187.pdf

        use 3 3x3(x3) convs instead of one 7x7. Stride is located in first conv.

        Fig2 b) describes this
        :return: nn.Sequential stem (3 convs + max pool, overall downsampling 4x)
        """
        c1 = ConvDropoutNormReLU(self.ops['conv_op'], self.input_channels, stem_features, 3, 2, False,
                                 self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True})
        c2 = ConvDropoutNormReLU(self.ops['conv_op'], stem_features, stem_features, 3, 1, False,
                                 self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True})
        c3 = ConvDropoutNormReLU(self.ops['conv_op'], stem_features, stem_features, 3, 1, False,
                                 self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True})
        pl = get_matching_pool_op(conv_op=self.ops['conv_op'], adaptive=False, pool_type='max')(2)
        return nn.Sequential(c1, c2, c3, pl)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class ResNet18_CIFAR(ResNetD):
    """ResNet-18-D, CIFAR layout ('18_cifar': keeps the encoder's default stem)."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='18_cifar', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet34_CIFAR(ResNetD):
    """ResNet-34-D, CIFAR layout ('34_cifar': keeps the encoder's default stem)."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='34_cifar', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet50_CIFAR(ResNetD):
    """ResNet-50-D (basic blocks), CIFAR layout ('50_cifar')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='50_cifar', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet152_CIFAR(ResNetD):
    """ResNet-152-D (basic blocks), CIFAR layout ('152_cifar')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='152_cifar', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet50bn_CIFAR(ResNetD):
    """ResNet-50-D with bottleneck blocks, CIFAR layout ('50_cifar_bn')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='50_cifar_bn', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet152bn_CIFAR(ResNetD):
    """ResNet-152-D with bottleneck blocks, CIFAR layout ('152_cifar_bn')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='152_cifar_bn', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet18(ResNetD):
    """ResNet-18-D, ImageNet layout ('18': ResNet-D stem built by ResNetD)."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='18', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet34(ResNetD):
    """ResNet-34-D, ImageNet layout ('34')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='34', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet50(ResNetD):
    """ResNet-50-D (basic blocks), ImageNet layout ('50')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='50', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet152(ResNetD):
    """ResNet-152-D (basic blocks), ImageNet layout ('152')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='152', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet50bn(ResNetD):
    """ResNet-50-D with bottleneck blocks, ImageNet layout ('50_bn')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='50_bn', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)


class ResNet152bn(ResNetD):
    """ResNet-152-D with bottleneck blocks, ImageNet layout ('152_bn')."""

    def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
                 final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False, squeeze_excitation_rd_ratio: float = 1. / 16):
        super().__init__(n_classes, n_input_channels, config='152_bn', input_dimension=input_dimension,
                         final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
                         squeeze_excitation=squeeze_excitation,
                         squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
if __name__ == '__main__':
    # Smoke test: trace a bottleneck ResNet-50 on a dummy ImageNet-sized batch and
    # export the graph as a PDF. Requires the optional 'hiddenlayer' package.
    data = torch.rand((1, 3, 224, 224))
    model = ResNet50bn(10, 3)

    import hiddenlayer as hl

    graph = hl.build_graph(model, data, transforms=None)
    graph.save("network_architecture.pdf")
    del graph

    # print(model.compute_conv_feature_map_size((32, 32)))
|
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/unet.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Type, List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
|
| 5 |
+
from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn.modules.conv import _ConvNd
|
| 8 |
+
from torch.nn.modules.dropout import _DropoutNd
|
| 9 |
+
|
| 10 |
+
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
|
| 11 |
+
from dynamic_network_architectures.building_blocks.unet_decoder import UNetDecoder, UNetDecoder_Seg
|
| 12 |
+
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PlainConvUNet(nn.Module):
    """Plain convolutional U-Net whose forward pass exposes both encoder skips and
    decoder feature maps (no segmentation head here)."""

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,  # activation
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False
                 ):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        # Broadcast scalar per-stage settings to one entry per stage.
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, (
            "n_conv_per_stage must have as many entries as we have "
            f"resolution stages. here: {n_stages}. "
            f"n_conv_per_stage: {n_conv_per_stage}")
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), (
            "n_conv_per_stage_decoder must have one less entries "
            f"as we have resolution stages. here: {n_stages} "
            f"stages, so it should have {n_stages - 1} entries. "
            f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}")
        self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes,
                                        strides, n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs,
                                        dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs,
                                        return_skips=True, nonlin_first=nonlin_first)
        self.decoder = UNetDecoder(self.encoder, n_conv_per_stage_decoder, deep_supervision,
                                   nonlin_first=nonlin_first)

    def forward(self, x):
        """Return (encoder skips, decoder outputs): latent multiscale features and
        per-pixel embeddings, each as a list of tensors."""
        skips = self.encoder(x)
        decoded = self.decoder(skips)
        return skips, decoded

    def compute_conv_feature_map_size(self, input_size):
        """Sum of encoder and decoder feature-map sizes for a spatial input size."""
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or " \
            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
            "Give input_size=(x, y(, z))!"
        encoder_size = self.encoder.compute_conv_feature_map_size(input_size)
        decoder_size = self.decoder.compute_conv_feature_map_size(input_size)
        return encoder_size + decoder_size
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class PlainConvUNet_Seg(nn.Module):
    """Plain convolutional U-Net with a segmentation decoder head: forward returns
    per-class logits rather than raw feature maps."""

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,  # activation
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False
                 ):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        # Broadcast scalar per-stage settings to one entry per stage.
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, (
            "n_conv_per_stage must have as many entries as we have "
            f"resolution stages. here: {n_stages}. "
            f"n_conv_per_stage: {n_conv_per_stage}")
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), (
            "n_conv_per_stage_decoder must have one less entries "
            f"as we have resolution stages. here: {n_stages} "
            f"stages, so it should have {n_stages - 1} entries. "
            f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}")
        self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes,
                                        strides, n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs,
                                        dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs,
                                        return_skips=True, nonlin_first=nonlin_first)
        self.decoder = UNetDecoder_Seg(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
                                       nonlin_first=nonlin_first)

    def forward(self, x):
        """Return segmentation output of the decoder head (num_classes channels)."""
        skips = self.encoder(x)
        return self.decoder(skips)

    def compute_conv_feature_map_size(self, input_size):
        """Sum of encoder and decoder feature-map sizes for a spatial input size."""
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or " \
            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
            "Give input_size=(x, y(, z))!"
        encoder_size = self.encoder.compute_conv_feature_map_size(input_size)
        decoder_size = self.decoder.compute_conv_feature_map_size(input_size)
        return encoder_size + decoder_size
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class ResidualEncoderUNet(nn.Module):
    """U-Net variant with a residual encoder (BasicBlockD / BottleneckD stages) and
    a plain UNet decoder; forward exposes both encoder skips and decoder outputs."""

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
                 bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
                 stem_channels: int = None
                 ):
        super().__init__()
        # Broadcast scalar per-stage settings to one entry per stage.
        if isinstance(n_blocks_per_stage, int):
            n_blocks_per_stage = [n_blocks_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_blocks_per_stage) == n_stages, (
            "n_blocks_per_stage must have as many entries as we have "
            f"resolution stages. here: {n_stages}. "
            f"n_blocks_per_stage: {n_blocks_per_stage}")
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), (
            "n_conv_per_stage_decoder must have one less entries "
            f"as we have resolution stages. here: {n_stages} "
            f"stages, so it should have {n_stages - 1} entries. "
            f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}")
        self.encoder = ResidualEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes,
                                       strides, n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs,
                                       dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, block,
                                       bottleneck_channels, return_skips=True, disable_default_stem=False,
                                       stem_channels=stem_channels)
        self.decoder = UNetDecoder(self.encoder, n_conv_per_stage_decoder, deep_supervision)

    def forward(self, x):
        """Return (encoder skips, decoder outputs): latent multiscale features and
        per-pixel embeddings, each as a list of tensors."""
        skips = self.encoder(x)
        decoded = self.decoder(skips)
        return skips, decoded

    def compute_conv_feature_map_size(self, input_size):
        """Sum of encoder and decoder feature-map sizes for a spatial input size."""
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or " \
            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
            "Give input_size=(x, y(, z))!"
        encoder_size = self.encoder.compute_conv_feature_map_size(input_size)
        decoder_size = self.decoder.compute_conv_feature_map_size(input_size)
        return encoder_size + decoder_size
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == '__main__':
    import sys
    sys.path.append('/remote-home/zihengzhao/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/model/dynamic-network-architectures-main')

    data = torch.rand((2, 3, 256, 256, 96)).cuda()

    model = PlainConvUNet(3, 6, (32, 64, 128, 256, 512, 768), nn.Conv3d, 3, (1, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2), 4,
                          (2, 2, 2, 2, 2), False, nn.BatchNorm3d, None, None, None, nn.ReLU, deep_supervision=True).cuda()

    # BUGFIX: PlainConvUNet.forward returns (encoder skips, decoder outputs), but the
    # original demo unpacked them as (dec_outs, enc_outs) — swapped — so the 'DEC'
    # and 'ENC' printouts described the wrong tensors.
    enc_skips, dec_outs = model(data)
    print('ENC')
    for t in enc_skips:
        print(t.shape)
    print('DEC')
    for t in dec_outs:
        print(t.shape)
    exit()

    # NOTE(review): everything below is unreachable (guarded by exit() above);
    # kept from the original as dormant demo code.
    if False:
        import hiddenlayer as hl

        g = hl.build_graph(model, data,
                           transforms=None)
        g.save("network_architecture.pdf")
        del g

    print(model.compute_conv_feature_map_size(data.shape[2:]))

    data = torch.rand((1, 4, 512, 512))

    model = PlainConvUNet(4, 8, (32, 64, 125, 256, 512, 512, 512, 512), nn.Conv2d, 3, (1, 2, 2, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2, 2, 2), 4,
                          (2, 2, 2, 2, 2, 2, 2), False, nn.BatchNorm2d, None, None, None, nn.ReLU, deep_supervision=True)

    if False:
        import hiddenlayer as hl

        g = hl.build_graph(model, data,
                           transforms=None)
        g.save("network_architecture.pdf")
        del g

    print(model.compute_conv_feature_map_size(data.shape[2:]))
|
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/vgg.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
|
| 5 |
+
from dynamic_network_architectures.building_blocks.helper import get_matching_pool_op, get_default_network_config
|
| 6 |
+
|
| 7 |
+
# Stage layouts for the supported VGG depths. The '*_cifar' variants use only
# four stages (small inputs); the plain variants follow the deeper six-stage layout.
_VGG_CONFIGS = {
    '16': {'features_per_stage': (64, 128, 256, 512, 512, 512),
           'n_conv_per_stage': (2, 2, 2, 3, 3, 3),
           'strides': (1, 2, 2, 2, 2, 2)},
    '19': {'features_per_stage': (64, 128, 256, 512, 512, 512),
           'n_conv_per_stage': (2, 2, 3, 3, 4, 4),
           'strides': (1, 2, 2, 2, 2, 2)},
    '16_cifar': {'features_per_stage': (64, 128, 256, 512),
                 'n_conv_per_stage': (2, 3, 5, 5),
                 'strides': (1, 2, 2, 2)},
    '19_cifar': {'features_per_stage': (64, 128, 256, 512),
                 'n_conv_per_stage': (3, 4, 5, 6),
                 'strides': (1, 2, 2, 2)},
}

# NOTE(review): _VGG_OPS is not referenced anywhere in this module — VGG resolves
# its ops via get_default_network_config instead. Kept for backward compatibility
# in case external code imports it; confirm before removing.
_VGG_OPS = {
    1: {'conv_op': nn.Conv1d, 'norm_op': nn.BatchNorm1d},
    2: {'conv_op': nn.Conv2d, 'norm_op': nn.BatchNorm2d},
    3: {'conv_op': nn.Conv3d, 'norm_op': nn.BatchNorm3d},
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class VGG(nn.Module):
    def __init__(self, n_classes: int, n_input_channel: int = 3, config='16', input_dimension=2):
        """
        This is not 1:1 VGG because it does not have the bloated fully connected layers at the end. Since these were
        counted towards the XX layers as well, we increase the number of convolutional layers so that we have the
        desired number of conv layers in total

        We also use batchnorm

        Args:
            n_classes: number of output classes for the linear classifier
            n_input_channel: number of input channels (e.g. 3 for RGB)
            config: key into _VGG_CONFIGS selecting the depth / stage layout
            input_dimension: dimensionality of the data (1, 2 or 3)
        """
        super().__init__()
        cfg = _VGG_CONFIGS[config]
        ops = get_default_network_config(dimension=input_dimension)
        self.encoder = PlainConvEncoder(
            n_input_channel, n_stages=len(cfg['features_per_stage']), features_per_stage=cfg['features_per_stage'],
            conv_op=ops['conv_op'],
            kernel_sizes=3, strides=cfg['strides'], n_conv_per_stage=cfg['n_conv_per_stage'], conv_bias=False,
            norm_op=ops['norm_op'], norm_op_kwargs=None, dropout_op=None, dropout_op_kwargs=None, nonlin=nn.ReLU,
            nonlin_kwargs={'inplace': True}, return_skips=False
        )
        # Global average pooling to one output position per channel, then linear classifier.
        self.gap = get_matching_pool_op(conv_op=ops['conv_op'], adaptive=True, pool_type='avg')(1)
        self.classifier = nn.Linear(cfg['features_per_stage'][-1], n_classes, True)

    def forward(self, x):
        """Return classification logits of shape (batch, n_classes)."""
        x = self.encoder(x)
        x = self.gap(x)
        # BUGFIX: was `.squeeze()`, which also drops the batch dimension when batch
        # size is 1 — the classifier then returned shape (n_classes,) instead of
        # (1, n_classes). flatten(1) only collapses the pooled spatial dims.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def compute_conv_feature_map_size(self, input_size):
        """Proxy to the encoder (the pooling/classifier head adds no conv feature maps)."""
        return self.encoder.compute_conv_feature_map_size(input_size)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class VGG16(VGG):
    """VGG-16 configuration (see VGG docstring for the caveats)."""

    def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
        super().__init__(n_classes=n_classes, n_input_channel=n_input_channel,
                         config='16', input_dimension=input_dimension)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class VGG19(VGG):
    """VGG-19 configuration (see VGG docstring for the caveats)."""

    def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
        super().__init__(n_classes=n_classes, n_input_channel=n_input_channel,
                         config='19', input_dimension=input_dimension)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class VGG16_cifar(VGG):
    """VGG-16 variant with a reduced stage plan for small (CIFAR-sized) inputs."""

    def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
        super().__init__(n_classes=n_classes, n_input_channel=n_input_channel,
                         config='16_cifar', input_dimension=input_dimension)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class VGG19_cifar(VGG):
    """VGG-19 variant with a reduced stage plan for small (CIFAR-sized) inputs."""

    def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
        super().__init__(n_classes=n_classes, n_input_channel=n_input_channel,
                         config='19_cifar', input_dimension=input_dimension)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == '__main__':
    # Smoke test: push one CIFAR-sized dummy batch through VGG19_cifar,
    # export the network graph to PDF and print the feature-map size.
    data = torch.rand((1, 3, 32, 32))

    model = VGG19_cifar(10, 3)
    # hiddenlayer is an optional visualization dependency; imported lazily so
    # the module itself does not require it
    import hiddenlayer as hl

    g = hl.build_graph(model, data,
                       transforms=None)
    g.save("network_architecture.pdf")
    del g

    print(model.compute_conv_feature_map_size((32, 32)))
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__init__.py
ADDED
|
File without changes
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (272 Bytes). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/helper.cpython-310.pyc
ADDED
|
Binary file (5.93 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/plain_conv_encoder.cpython-310.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/regularization.cpython-310.pyc
ADDED
|
Binary file (4.39 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual.cpython-310.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual_encoders.cpython-310.pyc
ADDED
|
Binary file (6.39 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/simple_conv_blocks.cpython-310.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/unet_decoder.cpython-310.pyc
ADDED
|
Binary file (6.85 kB). View file
|
|
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/helper.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Type
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 6 |
+
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
|
| 7 |
+
from torch.nn.modules.dropout import _DropoutNd
|
| 8 |
+
from torch.nn.modules.instancenorm import _InstanceNorm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def convert_dim_to_conv_op(dimension: int) -> Type[_ConvNd]:
    """
    :param dimension: 1, 2 or 3
    :return: conv Class of corresponding dimension
    """
    conv_ops = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    try:
        return conv_ops[dimension]
    except KeyError:
        raise ValueError("Unknown dimension. Only 1, 2 and 3 are supported")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def convert_conv_op_to_dim(conv_op: Type[_ConvNd]) -> int:
    """
    :param conv_op: conv class
    :return: dimension: 1, 2 or 3
    """
    for dim, candidate in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if conv_op == candidate:
            return dim
    raise ValueError("Unknown dimension. Only 1d 2d and 3d conv are supported. got %s" % str(conv_op))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_matching_pool_op(conv_op: Type[_ConvNd] = None,
                         dimension: int = None,
                         adaptive=False,
                         pool_type: str = 'avg') -> Type[torch.nn.Module]:
    """
    Return the pooling class (not an instance) matching the given conv op or
    dimensionality.

    You MUST set EITHER conv_op OR dimension. Do not set both!
    :param conv_op: conv class used to infer the dimensionality
    :param dimension: 1, 2 or 3
    :param adaptive: if True, return the Adaptive*Pool variant
    :param pool_type: either 'avg' or 'max'
    :return: pooling class
    """
    assert not ((conv_op is not None) and (dimension is not None)), \
        "You MUST set EITHER conv_op OR dimension. Do not set both!"
    assert pool_type in ['avg', 'max'], 'pool_type must be either avg or max'
    # bug fix: the original converted conv_op to dimension twice in a row;
    # do it once, then validate
    if conv_op is not None:
        dimension = convert_conv_op_to_dim(conv_op)
    assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'

    # (adaptive, pool_type, dimension) -> pooling class
    pool_classes = {
        (False, 'avg', 1): nn.AvgPool1d,
        (False, 'avg', 2): nn.AvgPool2d,
        (False, 'avg', 3): nn.AvgPool3d,
        (False, 'max', 1): nn.MaxPool1d,
        (False, 'max', 2): nn.MaxPool2d,
        (False, 'max', 3): nn.MaxPool3d,
        (True, 'avg', 1): nn.AdaptiveAvgPool1d,
        (True, 'avg', 2): nn.AdaptiveAvgPool2d,
        (True, 'avg', 3): nn.AdaptiveAvgPool3d,
        (True, 'max', 1): nn.AdaptiveMaxPool1d,
        (True, 'max', 2): nn.AdaptiveMaxPool2d,
        (True, 'max', 3): nn.AdaptiveMaxPool3d,
    }
    return pool_classes[(bool(adaptive), pool_type, dimension)]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_matching_instancenorm(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_InstanceNorm]:
    """
    You MUST set EITHER conv_op OR dimension. Do not set both!

    :param conv_op:
    :param dimension:
    :return:
    """
    assert not ((conv_op is not None) and (dimension is not None)), \
        "You MUST set EITHER conv_op OR dimension. Do not set both!"
    if conv_op is not None:
        dimension = convert_conv_op_to_dim(conv_op)
    if dimension is not None:
        assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
    # returns None when neither argument was given, matching the original
    norms = {1: nn.InstanceNorm1d, 2: nn.InstanceNorm2d, 3: nn.InstanceNorm3d}
    return norms.get(dimension)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def get_matching_convtransp(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_ConvTransposeNd]:
    """
    You MUST set EITHER conv_op OR dimension. Do not set both!

    :param conv_op:
    :param dimension:
    :return:
    """
    assert not ((conv_op is not None) and (dimension is not None)), \
        "You MUST set EITHER conv_op OR dimension. Do not set both!"
    if conv_op is not None:
        dimension = convert_conv_op_to_dim(conv_op)
    # returns None for an unrecognized / missing dimension, matching the original
    transposed = {1: nn.ConvTranspose1d, 2: nn.ConvTranspose2d, 3: nn.ConvTranspose3d}
    return transposed.get(dimension)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_matching_batchnorm(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_BatchNorm]:
    """
    You MUST set EITHER conv_op OR dimension. Do not set both!

    :param conv_op:
    :param dimension:
    :return:
    """
    assert not ((conv_op is not None) and (dimension is not None)), \
        "You MUST set EITHER conv_op OR dimension. Do not set both!"
    if conv_op is not None:
        dimension = convert_conv_op_to_dim(conv_op)
    assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
    return {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}[dimension]
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def get_matching_dropout(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_DropoutNd]:
    """
    You MUST set EITHER conv_op OR dimension. Do not set both!

    :param conv_op:
    :param dimension:
    :return:
    """
    assert not ((conv_op is not None) and (dimension is not None)), \
        "You MUST set EITHER conv_op OR dimension. Do not set both!"
    # bug fix: unlike the sibling get_matching_* helpers, conv_op was never
    # converted to a dimension here, so calling with conv_op always failed
    # the assertion below
    if conv_op is not None:
        dimension = convert_conv_op_to_dim(conv_op)
    assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
    if dimension == 1:
        # plain Dropout is used for the 1d case (matches the original mapping)
        return nn.Dropout
    elif dimension == 2:
        return nn.Dropout2d
    elif dimension == 3:
        return nn.Dropout3d
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def maybe_convert_scalar_to_list(conv_op, scalar):
    """
    useful for converting, for example, kernel_size=3 to [3, 3, 3] in case of nn.Conv3d
    :param conv_op:
    :param scalar:
    :return:
    """
    # sequences (and ndarrays) pass through untouched
    if isinstance(scalar, (tuple, list, np.ndarray)):
        return scalar
    repeats = {nn.Conv1d: 1, nn.Conv2d: 2, nn.Conv3d: 3}.get(conv_op)
    if repeats is None:
        raise RuntimeError("Invalid conv op: %s" % str(conv_op))
    return [scalar] * repeats
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def get_default_network_config(dimension: int = 2,
                               nonlin: str = "ReLU",
                               norm_type: str = "bn") -> dict:
    """
    Use this to get a standard configuration. A network configuration looks like this:

    config = {'conv_op': torch.nn.modules.conv.Conv2d,
              'dropout_op': torch.nn.modules.dropout.Dropout2d,
              'norm_op': torch.nn.modules.batchnorm.BatchNorm2d,
              'norm_op_kwargs': {'eps': 1e-05, 'affine': True},
              'nonlin': torch.nn.modules.activation.ReLU,
              'nonlin_kwargs': {'inplace': True}}

    There is no need to use get_default_network_config. You can create your own. Network configs are a convenient way of
    setting dimensionality, normalization and nonlinearity.

    :param dimension: integer denoting the dimension of the data. 1, 2 and 3 are accepted
    :param nonlin: string (ReLU or LeakyReLU)
    :param norm_type: string (bn=batch norm, in=instance norm)
    :return: dict
    """
    config = {
        'conv_op': convert_dim_to_conv_op(dimension),
        'dropout_op': get_matching_dropout(dimension=dimension),
        'norm_op_kwargs': None,  # this will use defaults
    }

    if norm_type == "bn":
        config['norm_op'] = get_matching_batchnorm(dimension=dimension)
    elif norm_type == "in":
        config['norm_op'] = get_matching_instancenorm(dimension=dimension)

    if nonlin == "LeakyReLU":
        config['nonlin'] = nn.LeakyReLU
        config['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
    elif nonlin == "ReLU":
        config['nonlin'] = nn.ReLU
        config['nonlin_kwargs'] = {'inplace': True}
    else:
        raise NotImplementedError('Unknown nonlin %s. Only "LeakyReLU" and "ReLU" are supported for now' % nonlin)

    return config
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/plain_conv_encoder.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Union, Type, List, Tuple
|
| 5 |
+
|
| 6 |
+
from torch.nn.modules.conv import _ConvNd
|
| 7 |
+
from torch.nn.modules.dropout import _DropoutNd
|
| 8 |
+
from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
|
| 9 |
+
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PlainConvEncoder(nn.Module):
    """Plain (non-residual) convolutional encoder: a sequence of stages, each a
    stack of conv blocks, optionally preceded by a pooling op for downsampling."""

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'
                 ):
        """
        :param pool: 'conv' uses strided convolution for downsampling; 'max'/'avg'
            insert a pooling layer and run the convs with stride 1
        :param return_skips: if True, forward returns the output of every stage
            (for a U-Net style decoder), otherwise only the last stage's output
        """
        super().__init__()
        # broadcast scalar per-stage settings to one entry per stage
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv drectly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                # downsample via pooling; the convs of this stage then use stride 1
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
                conv_stride = 1
            elif pool == 'conv':
                conv_stride = strides[s]
            else:
                raise RuntimeError()
            stage_modules.append(StackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ))
            stages.append(nn.Sequential(*stage_modules))
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return ret[-1]

    def compute_conv_feature_map_size(self, input_size):
        """Total number of conv feature-map elements produced for the given
        spatial input size (no batch/channel dims)."""
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        # bug fix: previously this always queried
                        # self.stages[s][-1] instead of the submodule sq being
                        # iterated, double-counting the last submodule whenever
                        # more than one implements the method
                        output += sq.compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output
|
| 104 |
+
|
| 105 |
+
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/regularization.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """
    This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py).

    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    # identity during eval, or when dropping is disabled
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # one bernoulli draw per sample, broadcast over all remaining dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DropPath(nn.Module):
    """
    This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py).

    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # delegates to the functional form; self.training toggles the behavior
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class SqueezeExcite(nn.Module):
    """
    This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/squeeze_excite.py)
    and slightly modified so that the convolution type can be adapted.

    SE Module as defined in original SE-Nets with a few additions
    Additions include:
        * divisor can be specified to keep channels % div == 0 (default: 8)
        * reduction channels can be specified directly by arg (if rd_channels is set)
        * reduction channels can be specified by float rd_ratio (default: 1/16)
        * global max pooling can be added to the squeeze aggregation
        * customizable activation, normalization, and gate layer

    NOTE(review): forward() reduces over dims (2, 3), which assumes a 4-D
    (N, C, H, W) input even though conv_op is configurable — confirm before
    using with 1d/3d convs.
    """
    def __init__(
            self, channels, conv_op, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=None, gate_layer=nn.Sigmoid):
        super(SqueezeExcite, self).__init__()
        self.add_maxpool = add_maxpool
        if not rd_channels:
            # derive the bottleneck width from the ratio, rounded to a multiple of rd_divisor
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.fc1 = conv_op(channels, rd_channels, kernel_size=1, bias=True)
        self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
        self.act = act_layer(inplace=True)
        self.fc2 = conv_op(rd_channels, channels, kernel_size=1, bias=True)
        self.gate = gate_layer()

    def forward(self, x):
        # squeeze: global average over the spatial dims
        x_se = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
        # excite: bottleneck MLP (as 1x1 convs) + gate, then rescale the input
        x_se = self.fc1(x_se)
        x_se = self.act(self.bn(x_se))
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    """
    This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/b7cb8d0337b3e7b50516849805ddb9be5fc11644/timm/models/layers/helpers.py#L25)
    """
    floor = min_value or divisor
    rounded = int(v + divisor / 2) // divisor * divisor
    new_v = max(floor, rounded)
    # Make sure that round down does not go down by more than 10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v
|
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/residual.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Union, Type
|
| 2 |
+
import torch.nn
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn.modules.conv import _ConvNd
|
| 5 |
+
from torch.nn.modules.dropout import _DropoutNd
|
| 6 |
+
|
| 7 |
+
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
|
| 8 |
+
from dynamic_network_architectures.building_blocks.simple_conv_blocks import ConvDropoutNormReLU
|
| 9 |
+
from dynamic_network_architectures.building_blocks.regularization import DropPath, SqueezeExcite
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BasicBlockD(nn.Module):
|
| 14 |
+
def __init__(self,
|
| 15 |
+
conv_op: Type[_ConvNd],
|
| 16 |
+
input_channels: int,
|
| 17 |
+
output_channels: int,
|
| 18 |
+
kernel_size: Union[int, List[int], Tuple[int, ...]],
|
| 19 |
+
stride: Union[int, List[int], Tuple[int, ...]],
|
| 20 |
+
conv_bias: bool = False,
|
| 21 |
+
norm_op: Union[None, Type[nn.Module]] = None,
|
| 22 |
+
norm_op_kwargs: dict = None,
|
| 23 |
+
dropout_op: Union[None, Type[_DropoutNd]] = None,
|
| 24 |
+
dropout_op_kwargs: dict = None,
|
| 25 |
+
nonlin: Union[None, Type[torch.nn.Module]] = None,
|
| 26 |
+
nonlin_kwargs: dict = None,
|
| 27 |
+
stochastic_depth_p: float = 0.0,
|
| 28 |
+
squeeze_excitation: bool = False,
|
| 29 |
+
squeeze_excitation_reduction_ratio: float = 1. / 16,
|
| 30 |
+
# todo wideresnet?
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
This implementation follows ResNet-D:
|
| 34 |
+
|
| 35 |
+
He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks."
|
| 36 |
+
Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
|
| 37 |
+
|
| 38 |
+
The skip has an avgpool (if needed) followed by 1x1 conv instead of just a strided 1x1 conv
|
| 39 |
+
|
| 40 |
+
:param conv_op:
|
| 41 |
+
:param input_channels:
|
| 42 |
+
:param output_channels:
|
| 43 |
+
:param kernel_size: refers only to convs in feature extraction path, not to 1x1x1 conv in skip
|
| 44 |
+
:param stride: only applies to first conv (and skip). Second conv always has stride 1
|
| 45 |
+
:param conv_bias:
|
| 46 |
+
:param norm_op:
|
| 47 |
+
:param norm_op_kwargs:
|
| 48 |
+
:param dropout_op: only the first conv can have dropout. The second never has
|
| 49 |
+
:param dropout_op_kwargs:
|
| 50 |
+
:param nonlin:
|
| 51 |
+
:param nonlin_kwargs:
|
| 52 |
+
:param stochastic_depth_p:
|
| 53 |
+
:param squeeze_excitation:
|
| 54 |
+
:param squeeze_excitation_reduction_ratio:
|
| 55 |
+
"""
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.input_channels = input_channels
|
| 58 |
+
self.output_channels = output_channels
|
| 59 |
+
stride = maybe_convert_scalar_to_list(conv_op, stride)
|
| 60 |
+
self.stride = stride
|
| 61 |
+
|
| 62 |
+
kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
|
| 63 |
+
|
| 64 |
+
if norm_op_kwargs is None:
|
| 65 |
+
norm_op_kwargs = {}
|
| 66 |
+
if nonlin_kwargs is None:
|
| 67 |
+
nonlin_kwargs = {}
|
| 68 |
+
|
| 69 |
+
self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, output_channels, kernel_size, stride, conv_bias,
|
| 70 |
+
norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
|
| 71 |
+
self.conv2 = ConvDropoutNormReLU(conv_op, output_channels, output_channels, kernel_size, 1, conv_bias, norm_op,
|
| 72 |
+
norm_op_kwargs, None, None, None, None)
|
| 73 |
+
|
| 74 |
+
self.nonlin2 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x
|
| 75 |
+
|
| 76 |
+
# Stochastic Depth
|
| 77 |
+
self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True
|
| 78 |
+
if self.apply_stochastic_depth:
|
| 79 |
+
self.drop_path = DropPath(drop_prob=stochastic_depth_p)
|
| 80 |
+
|
| 81 |
+
# Squeeze Excitation
|
| 82 |
+
self.apply_se = squeeze_excitation
|
| 83 |
+
if self.apply_se:
|
| 84 |
+
self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op,
|
| 85 |
+
rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8)
|
| 86 |
+
|
| 87 |
+
has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride])
|
| 88 |
+
requires_projection = (input_channels != output_channels)
|
| 89 |
+
|
| 90 |
+
if has_stride or requires_projection:
|
| 91 |
+
ops = []
|
| 92 |
+
if has_stride:
|
| 93 |
+
ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride))
|
| 94 |
+
if requires_projection:
|
| 95 |
+
ops.append(
|
| 96 |
+
ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False, norm_op,
|
| 97 |
+
norm_op_kwargs, None, None, None, None
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
self.skip = nn.Sequential(*ops)
|
| 101 |
+
else:
|
| 102 |
+
self.skip = lambda x: x
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
    """Residual forward pass.

    Runs the main path (conv1 -> conv2), optionally applies stochastic
    depth and squeeze-excitation to it, adds the skip branch, and finishes
    with the block's final nonlinearity.
    """
    shortcut = self.skip(x)
    main = self.conv2(self.conv1(x))
    if self.apply_stochastic_depth:
        main = self.drop_path(main)
    if self.apply_se:
        main = self.squeeze_excitation(main)
    main += shortcut
    return self.nonlin2(main)
|
| 113 |
+
|
| 114 |
+
def compute_conv_feature_map_size(self, input_size):
    """Return the total number of feature-map elements this block produces
    for a spatial input of size ``input_size``.

    ``input_size`` is the spatial size only, e.g. (x, y) or (x, y, z) --
    never batch or channel dimensions.
    """
    assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                "Give input_size=(x, y(, z))!"
    spatial_out = [dim // st for dim, st in zip(input_size, self.stride)]
    # conv1 and conv2 both emit output_channels feature maps at the strided resolution
    per_conv = np.prod([self.output_channels, *spatial_out], dtype=np.int64)
    is_strided = any(a != b for a, b in zip(input_size, spatial_out))
    if (self.input_channels != self.output_channels) or is_strided:
        # projection / pooled skip path contributes a full-size map as well
        assert isinstance(self.skip, nn.Sequential)
        skip_size = np.prod([self.output_channels, *spatial_out], dtype=np.int64)
    else:
        # identity skip: nothing extra is materialized
        assert not isinstance(self.skip, nn.Sequential)
        skip_size = 0
    return per_conv + per_conv + skip_size
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class BottleneckD(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 bottleneck_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False,
                 squeeze_excitation_reduction_ratio: float = 1. / 16
                 ):
        """
        Bottleneck residual block (1x1 -> kxk -> 1x1 convs). This implementation follows ResNet-D:

        He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks."
        Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.

        The stride sits in the 3x3 conv instead of the 1x1 conv!
        The skip has an avgpool (if needed) followed by 1x1 conv instead of just a strided 1x1 conv

        :param conv_op: convolution class (nn.Conv2d / nn.Conv3d etc.)
        :param input_channels: number of channels entering the block
        :param bottleneck_channels: reduced channel count used between the two 1x1 convs
        :param output_channels: number of channels of the block output (and of the residual sum)
        :param kernel_size: only affects the conv in the middle (typically 3x3). The other convs remain 1x1
        :param stride: only applies to the conv in the middle (and skip). Note that this deviates from the canonical
        ResNet implementation where the stride is applied to the first 1x1 conv. (This implementation follows ResNet-D)
        :param conv_bias: whether the convolutions use a bias term
        :param norm_op: normalization class, or None for no normalization
        :param norm_op_kwargs: kwargs for norm_op; defaults to {} when None
        :param dropout_op: only the second (kernel_size) conv can have dropout. The first and last conv (1x1(x1)) never have it
        :param dropout_op_kwargs: kwargs for dropout_op
        :param nonlin: nonlinearity class, or None for identity
        :param nonlin_kwargs: kwargs for nonlin; defaults to {} when None
        :param stochastic_depth_p: drop probability for stochastic depth on the main path (0 disables it)
        :param squeeze_excitation: whether to apply squeeze and excitation or not
        :param squeeze_excitation_reduction_ratio: channel reduction ratio of the SE module
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.bottleneck_channels = bottleneck_channels
        # normalize scalar stride/kernel_size to per-dimension lists matching conv_op's dimensionality
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        # 1x1 reduce: no dropout on the first conv, but norm + nonlin are applied
        self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, bottleneck_channels, 1, 1, conv_bias,
                                         norm_op, norm_op_kwargs, None, None, nonlin, nonlin_kwargs)
        # kxk conv carries the stride (ResNet-D) and is the only conv allowed dropout
        self.conv2 = ConvDropoutNormReLU(conv_op, bottleneck_channels, bottleneck_channels, kernel_size, stride,
                                         conv_bias,
                                         norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
        # 1x1 expand: no nonlin here — it is applied after the residual addition (nonlin3)
        self.conv3 = ConvDropoutNormReLU(conv_op, bottleneck_channels, output_channels, 1, 1, conv_bias, norm_op,
                                         norm_op_kwargs, None, None, None, None)

        # final nonlinearity after the residual sum; identity if nonlin is None
        self.nonlin3 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x

        # Stochastic Depth
        self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True
        if self.apply_stochastic_depth:
            self.drop_path = DropPath(drop_prob=stochastic_depth_p)

        # Squeeze Excitation
        self.apply_se = squeeze_excitation
        if self.apply_se:
            self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op,
                                                    rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8)

        # decide what the skip path needs: downsampling (avgpool), channel projection (1x1 conv), or neither
        has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride])
        requires_projection = (input_channels != output_channels)

        if has_stride or requires_projection:
            ops = []
            if has_stride:
                # ResNet-D: downsample via average pooling instead of a strided 1x1 conv
                ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride))
            if requires_projection:
                # 1x1 conv (no bias, no nonlin) to match the channel count of the main path
                ops.append(
                    ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False,
                                        norm_op, norm_op_kwargs, None, None, None, None
                                        )
                )
            self.skip = nn.Sequential(*ops)
        else:
            # identity skip — intentionally a plain lambda, not a module
            self.skip = lambda x: x

    def forward(self, x):
        """Main path (conv1 -> conv2 -> conv3), optional stochastic depth and SE,
        residual addition, then the final nonlinearity."""
        residual = self.skip(x)
        out = self.conv3(self.conv2(self.conv1(x)))
        if self.apply_stochastic_depth:
            out = self.drop_path(out)
        if self.apply_se:
            out = self.squeeze_excitation(out)
        out += residual
        return self.nonlin3(out)

    def compute_conv_feature_map_size(self, input_size):
        """Total number of feature-map elements produced for a spatial input
        of size ``input_size`` (spatial dims only, no batch/channel)."""
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        size_after_stride = [i // j for i, j in zip(input_size, self.stride)]
        # conv1 runs at full input resolution (stride lives in conv2)
        output_size_conv1 = np.prod([self.bottleneck_channels, *input_size], dtype=np.int64)
        # conv2
        output_size_conv2 = np.prod([self.bottleneck_channels, *size_after_stride], dtype=np.int64)
        # conv3
        output_size_conv3 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
        # skip conv (if applicable)
        if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]):
            assert isinstance(self.skip, nn.Sequential)
            output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
        else:
            assert not isinstance(self.skip, nn.Sequential)
            output_size_skip = 0
        return output_size_conv1 + output_size_conv2 + output_size_conv3 + output_size_skip
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class StackedResidualBlocks(nn.Module):
    def __init__(self,
                 n_blocks: int,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: Union[int, List[int], Tuple[int, ...]],
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 initial_stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
                 bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
                 stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False,
                 squeeze_excitation_reduction_ratio: float = 1. / 16
                 ):
        """
        A sequence of ``n_blocks`` residual blocks of type ``block``.

        The first block maps ``input_channels`` to ``output_channels[0]`` and applies
        ``initial_stride``; every subsequent block runs with stride 1 and consumes the
        previous block's output channels. ``output_channels`` (and, for BottleneckD,
        ``bottleneck_channels``) may be a scalar — it is then broadcast to all blocks —
        or a per-block sequence of length ``n_blocks``.

        :param n_blocks: number of residual blocks (must be > 0)
        :param conv_op: convolution class (nn.Conv2d / nn.Conv3d etc.)
        :param input_channels: channel count fed to the first block
        :param output_channels: per-block (or shared scalar) output channel count
        :param kernel_size: kernel size for all nxn (n != 1) convolutions
        :param initial_stride: stride of the first block only
        :param conv_bias: whether convolutions use bias
        :param norm_op: normalization class (nn.BatchNormNd, InstanceNormNd, ...)
        :param norm_op_kwargs: kwargs for norm_op
        :param dropout_op: dropout class, None for no dropout
        :param dropout_op_kwargs: kwargs for dropout_op
        :param nonlin: nonlinearity class
        :param nonlin_kwargs: kwargs for nonlin
        :param block: BasicBlockD or BottleneckD
        :param bottleneck_channels: required for BottleneckD — the reduced channel count
        each bottleneck uses between its 1x1 convs; ignored for BasicBlockD
        :param stochastic_depth_p: stochastic-depth drop probability passed to each block
        :param squeeze_excitation: whether each block applies squeeze-and-excitation
        :param squeeze_excitation_reduction_ratio: SE channel reduction ratio per block
        """
        super().__init__()
        assert n_blocks > 0, 'n_blocks must be > 0'
        assert block in [BasicBlockD, BottleneckD], 'block must be BasicBlockD or BottleneckD'
        # broadcast scalar channel specs to one entry per block
        if not isinstance(output_channels, (tuple, list)):
            output_channels = [output_channels] * n_blocks
        if not isinstance(bottleneck_channels, (tuple, list)):
            bottleneck_channels = [bottleneck_channels] * n_blocks

        modules = []
        for idx in range(n_blocks):
            # only the first block sees input_channels and initial_stride
            in_ch = input_channels if idx == 0 else output_channels[idx - 1]
            blk_stride = initial_stride if idx == 0 else 1
            if block == BasicBlockD:
                modules.append(
                    block(conv_op, in_ch, output_channels[idx], kernel_size, blk_stride, conv_bias,
                          norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs,
                          stochastic_depth_p, squeeze_excitation, squeeze_excitation_reduction_ratio))
            else:
                modules.append(
                    block(conv_op, in_ch, bottleneck_channels[idx], output_channels[idx], kernel_size,
                          blk_stride, conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
                          nonlin, nonlin_kwargs, stochastic_depth_p, squeeze_excitation,
                          squeeze_excitation_reduction_ratio))
        self.blocks = nn.Sequential(*modules)
        self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)
        self.output_channels = output_channels[-1]

    def forward(self, x):
        """Apply all stacked blocks in order."""
        return self.blocks(x)

    def compute_conv_feature_map_size(self, input_size):
        """Sum of the feature-map sizes of all contained blocks for a spatial
        input of size ``input_size`` (spatial dims only, no batch/channel)."""
        assert len(input_size) == len(self.initial_stride), "just give the image size without color/feature channels or " \
                                                            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                            "Give input_size=(x, y(, z))!"
        total = self.blocks[0].compute_conv_feature_map_size(input_size)
        # all blocks after the first operate at the downsampled resolution
        downsampled = [dim // st for dim, st in zip(input_size, self.initial_stride)]
        for blk in self.blocks[1:]:
            total += blk.compute_conv_feature_map_size(downsampled)
        return total
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
if __name__ == '__main__':
    # dummy 2D input: batch 1, 3 channels, 40x32 spatial
    sample = torch.rand((1, 3, 40, 32))

    # two stacked bottleneck blocks, anisotropic stride (1, 2)
    residual_stage = StackedResidualBlocks(
        2, nn.Conv2d, 24, (16, 16), (3, 3), (1, 2),
        norm_op=nn.BatchNorm2d, nonlin=nn.ReLU, nonlin_kwargs={'inplace': True},
        block=BottleneckD, bottleneck_channels=3)
    # stem conv (3 -> 24 channels) followed by the residual stage
    net = nn.Sequential(
        ConvDropoutNormReLU(nn.Conv2d,
                            3, 24, 3, 1, True, nn.BatchNorm2d, {}, None, None, nn.LeakyReLU,
                            {'inplace': True}),
        residual_stage)

    import hiddenlayer as hl

    # render the network graph to pdf for visual inspection
    graph = hl.build_graph(net, sample,
                           transforms=None)
    graph.save("network_architecture.pdf")
    del graph

    print(residual_stage.compute_conv_feature_map_size((40, 32)))