spc819 committed
Commit 7f3dfd7 · verified · 1 Parent(s): f249a24

Upload 69 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. CVPR25_TextSegFMData_with_class.json +0 -0
  3. config_CT.json +93 -0
  4. config_nonCT.json +13 -0
  5. data/__init__.py +0 -0
  6. data/default_resampling.py +208 -0
  7. data/resample_torch.py +162 -0
  8. data/resampling_test.py +593 -0
  9. environment.yml +211 -0
  10. evaluate/SurfaceDice.py +492 -0
  11. evaluate/__init__.py +0 -0
  12. evaluate/evaluator.py +379 -0
  13. evaluate/merge_after_evaluate.py +198 -0
  14. evaluate/metric.py +46 -0
  15. evaluate/params.py +153 -0
  16. inference_medals_nifti.py +1885 -0
  17. model/SwinUNETR.py +1116 -0
  18. model/__init__.py +0 -0
  19. model/base_bert.py +26 -0
  20. model/build_model.py +103 -0
  21. model/dynamic-network-architectures-main/.gitignore +113 -0
  22. model/dynamic-network-architectures-main/LICENCE +201 -0
  23. model/dynamic-network-architectures-main/README.md +25 -0
  24. model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/PKG-INFO +16 -0
  25. model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/SOURCES.txt +24 -0
  26. model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/dependency_links.txt +1 -0
  27. model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/not-zip-safe +1 -0
  28. model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/requires.txt +2 -0
  29. model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/top_level.txt +1 -0
  30. model/dynamic-network-architectures-main/dynamic_network_architectures/__init__.py +0 -0
  31. model/dynamic-network-architectures-main/dynamic_network_architectures/__pycache__/__init__.cpython-310.pyc +0 -0
  32. model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__init__.py +0 -0
  33. model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/__init__.cpython-310.pyc +0 -0
  34. model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/unet.cpython-310.pyc +0 -0
  35. model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/resnet.py +236 -0
  36. model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/unet.py +220 -0
  37. model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/vgg.py +85 -0
  38. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__init__.py +0 -0
  39. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/__init__.cpython-310.pyc +0 -0
  40. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/helper.cpython-310.pyc +0 -0
  41. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/plain_conv_encoder.cpython-310.pyc +0 -0
  42. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/regularization.cpython-310.pyc +0 -0
  43. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual.cpython-310.pyc +0 -0
  44. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual_encoders.cpython-310.pyc +0 -0
  45. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/simple_conv_blocks.cpython-310.pyc +0 -0
  46. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/unet_decoder.cpython-310.pyc +0 -0
  47. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/helper.py +242 -0
  48. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/plain_conv_encoder.py +105 -0
  49. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/regularization.py +86 -0
  50. model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/residual.py +371 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model/dynamic-network-architectures-main/imgs/Logos/HI_Logo.png filter=lfs diff=lfs merge=lfs -text
CVPR25_TextSegFMData_with_class.json ADDED
The diff for this file is too large to render. See raw diff
 
config_CT.json ADDED
@@ -0,0 +1,93 @@
+ {
+     "texts_soft_tissue": [
+         "Aorta in whole body CT",
+         "gallbladder in whole body CT",
+         "left kidney in whole body CT",
+         "right kidney in whole body CT",
+         "liver in whole body CT",
+         "Pancreas in whole body CT",
+         "Spleen in whole body CT",
+         "stomach in whole body CT",
+         "Left adrenal gland in whole body CT",
+         "right adrenal gland in whole body CT",
+         "Bladder in whole body CT",
+         "Esophagus in whole body CT",
+         "Heart in whole body CT",
+         "Pulmonary vein in whole body CT",
+         "Brachiocephalic trunk in whole body CT",
+         "Right subclavian artery in whole body CT",
+         "Left subclavian artery in whole body CT",
+         "Right common carotid artery in whole body CT",
+         "Left common carotid artery in whole body CT",
+         "Left brachiocephalic vein in whole body CT",
+         "Right brachiocephalic vein in whole body CT",
+         "Left atrial appendage in whole body CT",
+         "Superior vena cava in whole body CT",
+         "Inferior vena cava in whole body CT",
+         "Portal vein and splenic vein in whole body CT",
+         "Left iliac artery in whole body CT",
+         "Right iliac artery in whole body CT",
+         "Left iliac vena in whole body CT",
+         "Right iliac vena in whole body CT",
+         "Spinal cord in whole body CT",
+         "Left gluteus Maximus in whole body CT",
+         "Right gluteus Maximus in whole body CT",
+         "Left gluteus Medius in whole body CT",
+         "Right gluteus Medius in whole body CT",
+         "Left gluteus Minimus in whole body CT",
+         "Right gluteus Minimus in whole body CT",
+         "Left autochthon in whole body CT",
+         "Right autochthon in whole body CT",
+         "Left iliopsoas in whole body CT",
+         "Right iliopsoas in whole body CT"
+     ],
+     "texts_bone": [
+         "Vertebrae C7 in whole body CT",
+         "Vertebrae C6 in whole body CT",
+         "Vertebrae C5 in whole body CT",
+         "Vertebrae C4 in whole body CT",
+         "Vertebrae C3 in whole body CT",
+         "Vertebrae C2 in whole body CT",
+         "Vertebrae C1 in whole body CT",
+         "Vertebrae T12 in whole body CT",
+         "Vertebrae T11 in whole body CT",
+         "Vertebrae T10 in whole body CT",
+         "Vertebrae T9 in whole body CT",
+         "Vertebrae T8 in whole body CT",
+         "Vertebrae T7 in whole body CT",
+         "Vertebrae T6 in whole body CT",
+         "Vertebrae T5 in whole body CT",
+         "Vertebrae T4 in whole body CT",
+         "Vertebrae T3 in whole body CT",
+         "Vertebrae T2 in whole body CT",
+         "Vertebrae T1 in whole body CT",
+         "Left humerus in whole body CT",
+         "Right humerus in whole body CT",
+         "Left clavicula in whole body CT",
+         "Right clavicula in whole body CT",
+         "Left femur in whole body CT",
+         "Right femur in whole body CT",
+         "Left hip in whole body CT",
+         "Right hip in whole body CT"
+     ],
+     "texts_lung": [
+         "Left lung in whole body CT",
+         "Right lung in whole body CT"
+     ],
+     "window_settings": {
+         "soft_tissue": {
+             "window_level": 40,
+             "window_width": 400
+         },
+         "bone": {
+             "window_level": 500,
+             "window_width": 1500
+         },
+         "lung": {
+             "window_level": -600,
+             "window_width": 1500
+         }
+     },
+     "modality": "CT",
+     "instance_label": 0
+ }
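Note: `window_settings` follows the standard CT windowing convention, where `window_level` is the centre and `window_width` the range of Hounsfield units kept before rescaling, so the `soft_tissue` preset keeps roughly [-160, 240] HU and `lung` keeps [-1350, 150] HU. A minimal sketch of how such a preset could be applied (the helper name `apply_ct_window` is illustrative, not part of this repository):

```python
import numpy as np

def apply_ct_window(volume: np.ndarray, window_level: float, window_width: float) -> np.ndarray:
    """Clip a CT volume to [level - width/2, level + width/2] and rescale to [0, 1]."""
    lower = window_level - window_width / 2.0
    upper = window_level + window_width / 2.0
    windowed = np.clip(volume, lower, upper)
    return (windowed - lower) / (upper - lower)

# e.g. the "soft_tissue" preset above: level 40, width 400 -> clip to [-160, 240] HU
# windowed = apply_ct_window(ct_volume_in_hu, 40, 400)
```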
config_nonCT.json ADDED
@@ -0,0 +1,13 @@
+ {
+     "texts": [
+         "Spleen in MRI"
+     ],
+     "normalization_settings": {
+         "percentile_lower": 0.5,
+         "percentile_upper": 99.5,
+         "preserve_zero": true
+     },
+     "modality": "MRI",
+     "instance_label": 0
+ }
+
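Note: `normalization_settings` describes percentile-based intensity normalization for non-CT modalities. A small sketch of one plausible reading of these keys (the helper name is illustrative; the repository's own preprocessing may differ in detail):

```python
import numpy as np

def percentile_normalize(volume: np.ndarray,
                         percentile_lower: float = 0.5,
                         percentile_upper: float = 99.5,
                         preserve_zero: bool = True) -> np.ndarray:
    """Clip to the given intensity percentiles and min-max rescale to [0, 1]."""
    # With preserve_zero, statistics are taken over non-zero voxels so a large
    # zero-valued background does not dominate the percentiles, and exact zeros stay zero.
    values = volume[volume != 0] if preserve_zero and np.any(volume != 0) else volume
    lo, hi = np.percentile(values, [percentile_lower, percentile_upper])
    out = np.clip(volume.astype(np.float32), lo, hi)
    out = (out - lo) / max(float(hi - lo), 1e-8)
    if preserve_zero:
        out[volume == 0] = 0.0
    return out
```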
data/__init__.py ADDED
File without changes
data/default_resampling.py ADDED
@@ -0,0 +1,208 @@
+ from collections import OrderedDict
+ from copy import deepcopy
+ from typing import Union, Tuple, List
+
+ import numpy as np
+ import pandas as pd
+ import sklearn
+ import torch
+ from batchgenerators.augmentations.utils import resize_segmentation
+ from scipy.ndimage import map_coordinates
+ from skimage.transform import resize
+
+ ANISO_THRESHOLD = 3  # determines when a sample is considered anisotropic (3 means that the spacing in the low
+ # resolution axis must be 3x as large as the next largest spacing)
+
+ def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD):
+     do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold
+     return do_separate_z
+
+
+ def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]):
+     axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]  # find which axis is anisotropic
+     return axis
+
+
+ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
+                       old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+                       new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
+     assert len(old_spacing) == len(old_shape)
+     assert len(old_shape) == len(new_spacing)
+     new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)])
+     return new_shape
+
+
+
+ def determine_do_sep_z_and_axis(
+         force_separate_z: bool,
+         current_spacing,
+         new_spacing,
+         separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
+     if force_separate_z is not None:
+         do_separate_z = force_separate_z
+         if force_separate_z:
+             axis = get_lowres_axis(current_spacing)
+         else:
+             axis = None
+     else:
+         if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
+             do_separate_z = True
+             axis = get_lowres_axis(current_spacing)
+         elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
+             do_separate_z = True
+             axis = get_lowres_axis(new_spacing)
+         else:
+             do_separate_z = False
+             axis = None
+
+     if axis is not None:
+         if len(axis) == 3:
+             do_separate_z = False
+             axis = None
+         elif len(axis) == 2:
+             # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
+             # separately in the out of plane axis
+             do_separate_z = False
+             axis = None
+         else:
+             axis = axis[0]
+     return do_separate_z, axis
+
+
+ def resample_data_or_seg_to_spacing(data: np.ndarray,
+                                     current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+                                     new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+                                     is_seg: bool = False,
+                                     order: int = 3, order_z: int = 0,
+                                     force_separate_z: Union[bool, None] = False,
+                                     separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
+     do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
+                                                       separate_z_anisotropy_threshold)
+
+     if data is not None:
+         assert data.ndim == 4, "data must be c x y z"
+
+     shape = np.array(data.shape)
+     new_shape = compute_new_shape(shape[1:], current_spacing, new_spacing)
+
+     data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
+     return data_reshaped
+
+
+ def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
+                                   new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
+                                   current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+                                   new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+                                   is_seg: bool = False,
+                                   order: int = 3, order_z: int = 0,
+                                   force_separate_z: Union[bool, None] = False,
+                                   separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
+     """
+     needed for segmentation export. Stupid, I know
+     """
+     if isinstance(data, torch.Tensor):
+         data = data.numpy()
+
+     do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
+                                                       separate_z_anisotropy_threshold)
+
+     if data is not None:
+         assert data.ndim == 4, "data must be c x y z"
+
+     data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
+     return data_reshaped
+
+
+ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
+                          is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,
+                          do_separate_z: bool = False, order_z: int = 0, dtype_out=None):
+     """
+     separate_z=True will resample with order 0 along z
+     :param data:
+     :param new_shape:
+     :param is_seg:
+     :param axis:
+     :param order:
+     :param do_separate_z:
+     :param order_z: only applies if do_separate_z is True
+     :return:
+     """
+     assert data.ndim == 4, "data must be (c, x, y, z)"
+     assert len(new_shape) == data.ndim - 1
+
+     if is_seg:
+         resize_fn = resize_segmentation
+         kwargs = OrderedDict()
+     else:
+         resize_fn = resize
+         kwargs = {'mode': 'edge', 'anti_aliasing': False}
+     shape = np.array(data[0].shape)
+     new_shape = np.array(new_shape)
+     if dtype_out is None:
+         dtype_out = data.dtype
+     reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out)
+     if np.any(shape != new_shape):
+         data = data.astype(float, copy=False)
+         if do_separate_z:
+             # print("separate z, order in z is", order_z, "order inplane is", order)
+             assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic'
+             if axis == 0:
+                 new_shape_2d = new_shape[1:]
+             elif axis == 1:
+                 new_shape_2d = new_shape[[0, 2]]
+             else:
+                 new_shape_2d = new_shape[:-1]
+
+             for c in range(data.shape[0]):
+                 tmp = deepcopy(new_shape)
+                 tmp[axis] = shape[axis]
+                 reshaped_here = np.zeros(tmp)
+                 for slice_id in range(shape[axis]):
+                     if axis == 0:
+                         reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)
+                     elif axis == 1:
+                         reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
+                     else:
+                         reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)
+                 if shape[axis] != new_shape[axis]:
+
+                     # The following few lines are blatantly copied and modified from sklearn's resize()
+                     rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
+                     orig_rows, orig_cols, orig_dim = reshaped_here.shape
+
+                     # align_corners=False
+                     row_scale = float(orig_rows) / rows
+                     col_scale = float(orig_cols) / cols
+                     dim_scale = float(orig_dim) / dim
+
+                     map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
+                     map_rows = row_scale * (map_rows + 0.5) - 0.5
+                     map_cols = col_scale * (map_cols + 0.5) - 0.5
+                     map_dims = dim_scale * (map_dims + 0.5) - 0.5
+
+                     coord_map = np.array([map_rows, map_cols, map_dims])
+                     if not is_seg or order_z == 0:
+                         reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None]
+                     else:
+                         unique_labels = np.sort(pd.unique(reshaped_here.ravel()))  # np.unique(reshaped_data)
+                         for i, cl in enumerate(unique_labels):
+                             reshaped_final[c][np.round(
+                                 map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z,
+                                                 mode='nearest')) > 0.5] = cl
+                 else:
+                     reshaped_final[c] = reshaped_here
+         else:
+             # print("no separate z, order", order)
+             for c in range(data.shape[0]):
+                 reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
+         return reshaped_final
+     else:
+         # print("no resampling necessary")
+         return data
+
+
+ if __name__ == '__main__':
+     input_array = np.random.random((1, 42, 231, 142))
+     output_shape = (52, 256, 256)
+     out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=3, order=1, order_z=0, do_separate_z=True)
+     print(out.shape, input_array.shape)
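A brief, hypothetical usage sketch for the helpers above (shapes and spacings are invented). Passing `force_separate_z=None` lets the `ANISO_THRESHOLD` heuristic decide; here 5.0 / 0.8 > 3, so the low-resolution axis is resampled slice-wise in-plane and then interpolated along that axis with `order_z`:

```python
import numpy as np
from data.default_resampling import compute_new_shape, resample_data_or_seg_to_spacing

# (c, x, y, z) volume with strongly anisotropic spacing
image = np.random.random((1, 40, 256, 256)).astype(np.float32)
current_spacing = (5.0, 0.8, 0.8)
target_spacing = (1.5, 1.5, 1.5)

print(compute_new_shape(image.shape[1:], current_spacing, target_spacing))  # [133 137 137]
resampled = resample_data_or_seg_to_spacing(image, current_spacing, target_spacing,
                                            is_seg=False, order=3, order_z=0,
                                            force_separate_z=None)
print(resampled.shape)  # (1, 133, 137, 137)
```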
data/resample_torch.py ADDED
@@ -0,0 +1,162 @@
+ from copy import deepcopy
+ from typing import Union, Tuple, List
+
+ import numpy as np
+ import torch
+ from einops import rearrange
+ from torch.nn import functional as F
+
+ from data.default_resampling import determine_do_sep_z_and_axis
+
+ ANISO_THRESHOLD = 3  # determines when a sample is considered anisotropic (3 means that the spacing in the low
+ # resolution axis must be 3x as large as the next largest spacing)
+
+ def resample_torch_simple(
+         data: Union[torch.Tensor, np.ndarray],
+         new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
+         is_seg: bool = False,
+         num_threads: int = 4,
+         device: torch.device = torch.device('cpu'),
+         memefficient_seg_resampling: bool = False,
+         mode='linear'
+ ):
+     if mode == 'linear':
+         if data.ndim == 4:
+             torch_mode = 'trilinear'
+         elif data.ndim == 3:
+             torch_mode = 'bilinear'
+         else:
+             raise RuntimeError
+     else:
+         torch_mode = mode
+
+     if isinstance(new_shape, np.ndarray):
+         new_shape = [int(i) for i in new_shape]
+
+     if all([i == j for i, j in zip(new_shape, data.shape[1:])]):
+         return data
+     else:
+         n_threads = torch.get_num_threads()
+         torch.set_num_threads(num_threads)
+         new_shape = tuple(new_shape)
+         with torch.no_grad():
+
+             input_was_numpy = isinstance(data, np.ndarray)
+             if input_was_numpy:
+                 data = torch.from_numpy(data).to(device)
+             else:
+                 orig_device = deepcopy(data.device)
+                 data = data.to(device)
+
+             if is_seg:
+                 unique_values = torch.unique(data)
+                 result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
+                 result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
+                 if not memefficient_seg_resampling:
+                     # believe it or not, the implementation below is 3x as fast (at least on Liver CT and on CPU)
+                     # Why? Because argmax is slow. The implementation below immediately sets most locations and only lets the
+                     # uncertain ones be determined by argmax
+
+                     # unique_values = torch.unique(data)
+                     # result = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16)
+                     # for i, u in enumerate(unique_values):
+                     #     result[i] = F.interpolate((data[None] == u).float() * 1000, new_shape, mode='trilinear', antialias=False)[0]
+                     # result = unique_values[result.argmax(0)]
+
+                     result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
+                                              device=device)
+                     scale_factor = 1000
+                     done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
+                     for i, u in enumerate(unique_values):
+                         result_tmp[i] = \
+                             F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
+                                           antialias=False)[0]
+                         mask = result_tmp[i] > (0.7 * scale_factor)
+                         result[mask] = u.item()
+                         done_mask |= mask
+                     if not torch.all(done_mask):
+                         # print('resolving argmax', torch.sum(~done_mask), "voxels to go")
+                         result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
+                 else:
+                     for i, u in enumerate(unique_values):
+                         if u == 0:
+                             pass
+                         result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[
+                                    0] > 0.5] = u
+             else:
+                 result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]
+             if input_was_numpy:
+                 result = result.cpu().numpy()
+             else:
+                 result = result.to(orig_device)
+         torch.set_num_threads(n_threads)
+         return result
+
+
+ def resample_torch_fornnunet(
+         data: Union[torch.Tensor, np.ndarray],
+         new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
+         current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+         new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+         is_seg: bool = False,
+         num_threads: int = 4,
+         device: torch.device = torch.device('cpu'),
+         memefficient_seg_resampling: bool = False,
+         force_separate_z: Union[bool, None] = None,
+         separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
+         mode='linear',
+         aniso_axis_mode='nearest-exact'
+ ):
+     """
+     data must be c, x, y, z
+     """
+     assert data.ndim == 4, "data must be c, x, y, z"
+     new_shape = [int(i) for i in new_shape]
+     orig_shape = data.shape
+
+     do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
+                                                       separate_z_anisotropy_threshold)
+     # print('shape', data.shape, 'current_spacing', current_spacing, 'new_spacing', new_spacing, 'do_separate_z', do_separate_z, 'axis', axis)
+
+     if do_separate_z:
+         was_numpy = isinstance(data, np.ndarray)
+         if was_numpy:
+             data = torch.from_numpy(data)
+
+         if isinstance(axis, list):
+             assert len(axis) == 1
+             axis = axis[0]
+         else:
+             pass
+
+         tmp = "xyz"
+         axis_letter = tmp[axis]
+         others_int = [i for i in range(3) if i != axis]
+         others = [tmp[i] for i in others_int]
+
+         # reshape by overloading c channel
+         data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
+
+         # reshape in-plane
+         tmp_new_shape = [new_shape[i] for i in others_int]
+         data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
+                                      memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
+         data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
+                          **{
+                              axis_letter: orig_shape[axis + 1],
+                              others[0]: tmp_new_shape[0],
+                              others[1]: tmp_new_shape[1]
+                          }
+                          )
+         # reshape out of plane w/ nearest
+         data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
+                                      memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
+         if was_numpy:
+             data = data.numpy()
+         return data
+     else:
+         return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling)
+
+
+ if __name__ == '__main__':
+     torch.set_num_threads(16)
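A small, hypothetical usage example of `resample_torch_fornnunet` on an anisotropic segmentation (all values are invented). The spacing ratio 5.0 / 0.8 > 3 triggers the separate-z path: the volume is resampled in-plane first and then along the low-resolution axis with `aniso_axis_mode='nearest-exact'`:

```python
import numpy as np
import torch
from data.resample_torch import resample_torch_fornnunet

seg = np.zeros((1, 20, 64, 64), dtype=np.int8)  # (c, x, y, z)
seg[0, 8:12, 20:40, 20:40] = 1

out = resample_torch_fornnunet(
    seg,
    new_shape=(40, 51, 51),            # roughly matches the spacing change below
    current_spacing=(5.0, 0.8, 0.8),
    new_spacing=(2.5, 1.0, 1.0),
    is_seg=True,
    device=torch.device('cpu'),
)
print(out.shape, out.dtype)  # (1, 40, 51, 51) int8
```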
data/resampling_test.py ADDED
@@ -0,0 +1,593 @@
1
+ from typing import Union, Tuple, List
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ import time
7
+ from copy import deepcopy
8
+ from default_resampling import determine_do_sep_z_and_axis
9
+ import psutil
10
+ import nibabel as nib
11
+ import os
12
+ from pathlib import Path
13
+
14
+ ANISO_THRESHOLD = 3
15
+
16
+ def compute_new_shape(current_shape: Union[Tuple[int, ...], List[int], np.ndarray],
17
+ current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
18
+ target_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> List[int]:
19
+ """Compute new shape based on spacing ratios."""
20
+ current_shape = np.array(current_shape)
21
+ current_spacing = np.array(current_spacing)
22
+ target_spacing = np.array(target_spacing)
23
+ return [int(round(s * (cs / ts))) for s, cs, ts in zip(current_shape, current_spacing, target_spacing)]
24
+
25
+ def optimized_3d_resample(
26
+ data: Union[torch.Tensor, np.ndarray],
27
+ current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
28
+ target_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
29
+ is_seg: bool = False,
30
+ device: torch.device = torch.device('cpu'),
31
+ num_threads: int = 8,
32
+ chunk_size: int = 64,
33
+ force_separate_z: Union[bool, None] = None,
34
+ separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
35
+ preserve_range: bool = True
36
+ ) -> Union[torch.Tensor, np.ndarray]:
37
+ """
38
+ Optimized 3D image resampling with adaptive interpolation and chunked processing.
39
+
40
+ Args:
41
+ data: Input 3D volume [C, D, H, W] or [D, H, W]
42
+ current_spacing: Current voxel spacing (z, y, x)
43
+ target_spacing: Target voxel spacing (z, y, x)
44
+ is_seg: Whether the input is a segmentation mask
45
+ device: Torch device for computation
46
+ num_threads: Number of threads for CPU operations
47
+ chunk_size: Size of chunks for large volume processing
48
+ force_separate_z: Force separate z resampling
49
+ separate_z_anisotropy_threshold: Threshold for anisotropic resampling
50
+ preserve_range: Preserve original value range for non-segmentation data
51
+
52
+ Returns:
53
+ Resampled 3D volume
54
+ """
55
+ print(f"\nStarting optimized_3d_resample with input shape: {data.shape}, is_seg: {is_seg}")
56
+ input_was_numpy = isinstance(data, np.ndarray)
57
+ if input_was_numpy:
58
+ data = torch.from_numpy(data).to(device)
59
+ else:
60
+ data = data.to(device)
61
+ print(f"Input converted to tensor on {device}, shape: {data.shape}")
62
+
63
+ if data.ndim == 3:
64
+ data = data.unsqueeze(0)
65
+ assert data.ndim == 4, "Data must be 3D or 4D (C, D, H, W)"
66
+
67
+ new_shape = compute_new_shape(data.shape[1:], current_spacing, target_spacing)
68
+ print(f"Computed new shape: {new_shape} from current_spacing: {current_spacing}, target_spacing: {target_spacing}")
69
+
70
+ if all(i == j for i, j in zip(new_shape, data.shape[1:])):
71
+ print("No resampling needed, shapes identical.")
72
+ return data.cpu().numpy() if input_was_numpy else data
73
+
74
+ mode = 'nearest' if is_seg else 'trilinear'
75
+ aniso_axis_mode = 'nearest-exact' if is_seg else 'linear'
76
+ print(f"Interpolation mode: {mode}, Anisotropic axis mode: {aniso_axis_mode}")
77
+
78
+ do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing,
79
+ target_spacing, separate_z_anisotropy_threshold)
80
+ print(f"Do separate Z: {do_separate_z}, Axis: {axis}")
81
+
82
+ if preserve_range and not is_seg:
83
+ v_min, v_max = data.min(), data.max()
84
+ print(f"Preserving range for non-segmentation data: min={v_min.item():.4f}, max={v_max.item():.4f}")
85
+
86
+ torch.set_num_threads(num_threads)
87
+ print(f"Set number of threads to {num_threads}")
88
+
89
+ start_time = time.time()
90
+ if do_separate_z:
91
+ tmp = "xyz"
92
+ axis_letter = tmp[axis]
93
+ others_int = [i for i in range(3) if i != axis]
94
+ others = [tmp[i] for i in others_int]
95
+ print(f"Separate Z resampling along axis {axis_letter}, others: {others}")
96
+
97
+ tmp_new_shape = [new_shape[i] for i in others_int]
98
+ print(f"First pass: Resampling to shape {tmp_new_shape} for axes {others}")
99
+ data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
100
+ print(f"Rearranged data shape: {data.shape}")
101
+ data = _chunked_resample(data, tmp_new_shape, mode, chunk_size, device, is_seg)
102
+ print(f"After first pass resampling, shape: {data.shape}")
103
+
104
+ data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
105
+ **{axis_letter: data.shape[1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
106
+ print(f"Rearranged back to shape: {data.shape}")
107
+ data = _chunked_resample(data, new_shape, aniso_axis_mode, chunk_size, device, is_seg)
108
+ print(f"After second pass resampling, final shape: {data.shape}")
109
+ else:
110
+ print(f"Direct resampling to shape: {new_shape}")
111
+ data = _chunked_resample(data, new_shape, mode, chunk_size, device, is_seg)
112
+ print(f"After direct resampling, final shape: {data.shape}")
113
+ resample_time = time.time() - start_time
114
+ print(f"Resampling completed in {resample_time:.3f}s")
115
+
116
+ if is_seg:
117
+ unique_values = torch.unique(data)
118
+ result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
119
+ data = data.round().to(result_dtype)
120
+ print(f"Segmentation data rounded and converted to {result_dtype}, unique values: {unique_values.tolist()}")
121
+
122
+ if preserve_range and not is_seg:
123
+ data = torch.clamp(data, v_min, v_max)
124
+ print(f"Clamped data to original range: min={v_min.item():.4f}, max={v_max.item():.4f}")
125
+
126
+ output = data.cpu().numpy() if input_was_numpy else data
127
+ print(f"Output shape: {output.shape}, type: {type(output)}")
128
+ return output
129
+
130
+ def _chunked_resample(
131
+ volume: torch.Tensor,
132
+ target_shape: Tuple[int, ...],
133
+ mode: str,
134
+ chunk_size: int,
135
+ device: torch.device,
136
+ is_seg: bool
137
+ ) -> torch.Tensor:
138
+ """Chunked resampling for large volumes with adaptive chunk sizing."""
139
+ print(f"\nStarting _chunked_resample with input shape: {volume.shape}, target shape: {target_shape}")
140
+ C, D, H, W = volume.shape
141
+ tD, tH, tW = target_shape
142
+
143
+ # Adaptive chunk size based on available memory
144
+ if device.type == 'cpu':
145
+ available_memory = psutil.virtual_memory().available / 1024**2 # in MB
146
+ else:
147
+ total_memory = torch.cuda.get_device_properties(device).total_memory / 1024**2 # in MB
148
+ allocated_memory = torch.cuda.memory_allocated(device) / 1024**2
149
+ available_memory = total_memory - allocated_memory
150
+
151
+ mem_per_voxel = volume.element_size() * volume.nelement() / volume.numel()
152
+ target_voxel_count = C * tD * tH * tW
153
+ chunk_mem_ratio = 0.5 if device.type == 'cpu' else 0.3
154
+ adaptive_chunk_size = max(
155
+ 32,
156
+ min(chunk_size, int((available_memory * chunk_mem_ratio / mem_per_voxel / C) ** (1/3)))
157
+ )
158
+
159
+ # Early return for small volumes
160
+ if D * H * W <= 128**3:
161
+ with torch.cuda.amp.autocast(enabled=not is_seg):
162
+ start_time = time.time()
163
+ # Cast to float for interpolation if is_seg and mode is nearest
164
+ input_tensor = volume.float() if is_seg and mode == 'nearest' else volume
165
+ result = F.interpolate(
166
+ input_tensor.unsqueeze(0),
167
+ size=target_shape,
168
+ mode=mode,
169
+ align_corners=False if mode != 'nearest' else None
170
+ ).squeeze(0)
171
+ # Convert back to original dtype for segmentation
172
+ if is_seg:
173
+ result = result.round().to(volume.dtype)
174
+ # print(f"Direct interpolation completed in {time.time() - start_time:.3f}s, output shape: {result.shape}")
175
+ return result
176
+
177
+ result = torch.zeros((C, tD, tH, tW), device=device, dtype=volume.dtype)
178
+
179
+ out_chunk_size = max(1, int(adaptive_chunk_size * min(tD/D, tH/H, tW/W)))
180
+
181
+ for c in range(C):
182
+ for z in range(0, tD, out_chunk_size):
183
+ z_end = min(z + out_chunk_size, tD)
184
+ for y in range(0, tH, out_chunk_size):
185
+ y_end = min(y + out_chunk_size, tH)
186
+ for x in range(0, tW, out_chunk_size):
187
+ x_end = min(x + out_chunk_size, tW)
188
+
189
+ in_z = max(0, int(z * D / tD) - 1)
190
+ in_z_end = min(D, int(z_end * D / tD) + 2)
191
+ in_y = max(0, int(y * H / tH) - 1)
192
+ in_y_end = min(H, int(y_end * H / tH) + 2)
193
+ in_x = max(0, int(x * W / tW) - 1)
194
+ in_x_end = min(W, int(x_end * W / tW) + 2)
195
+
196
+ chunk = volume[c:c+1, in_z:in_z_end, in_y:in_y_end, in_x:in_x_end]
197
+ chunk_target = (z_end - z, y_end - y, x_end - x)
198
+
199
+ with torch.cuda.amp.autocast(enabled=not is_seg):
200
+ start_time = time.time()
201
+ # Cast to float for interpolation if is_seg and mode is nearest
202
+ input_chunk = chunk.float() if is_seg and mode == 'nearest' else chunk
203
+ resampled_chunk = F.interpolate(
204
+ input_chunk.unsqueeze(0),
205
+ size=chunk_target,
206
+ mode=mode,
207
+ align_corners=False if mode != 'nearest' else None
208
+ ).squeeze(0)
209
+ # Convert back to original dtype for segmentation
210
+ if is_seg:
211
+ resampled_chunk = resampled_chunk.round().to(volume.dtype)
212
+ # print(f"Chunk interpolation completed in {time.time() - start_time:.3f}s, shape: {resampled_chunk.shape}")
213
+
214
+ result[c, z:z_end, y:y_end, x:x_end] = resampled_chunk
215
+ del chunk, resampled_chunk
216
+ if device.type == 'cuda':
217
+ torch.cuda.empty_cache()
218
+
219
+ return result
220
+
221
+ def resample_torch_simple(
222
+ data: Union[torch.Tensor, np.ndarray],
223
+ new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
224
+ is_seg: bool = False,
225
+ num_threads: int = 4,
226
+ device: torch.device = torch.device('cpu'),
227
+ memefficient_seg_resampling: bool = False,
228
+ mode: str = 'linear'
229
+ ) -> Union[torch.Tensor, np.ndarray]:
230
+ if mode == 'linear':
231
+ torch_mode = 'trilinear' if data.ndim == 4 else 'bilinear'
232
+ else:
233
+ torch_mode = mode
234
+
235
+ if isinstance(new_shape, np.ndarray):
236
+ new_shape = [int(i) for i in new_shape]
237
+
238
+ if all([i == j for i, j in zip(new_shape, data.shape[1:])]):
239
+ return data
240
+
241
+ n_threads = torch.get_num_threads()
242
+ torch.set_num_threads(num_threads)
243
+ new_shape = tuple(new_shape)
244
+ with torch.no_grad():
245
+ input_was_numpy = isinstance(data, np.ndarray)
246
+ if input_was_numpy:
247
+ data = torch.from_numpy(data).to(device)
248
+ else:
249
+ orig_device = deepcopy(data.device)
250
+ data = data.to(device)
251
+
252
+ if is_seg:
253
+ unique_values = torch.unique(data)
254
+ result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
255
+ result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
256
+ if not memefficient_seg_resampling:
257
+ result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
258
+ device=device)
259
+ scale_factor = 1000
260
+ done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
261
+ for i, u in enumerate(unique_values):
262
+ result_tmp[i] = F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
263
+ antialias=False)[0]
264
+ mask = result_tmp[i] > (0.7 * scale_factor)
265
+ result[mask] = u.item()
266
+ done_mask |= mask
267
+ if not torch.all(done_mask):
268
+ result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
269
+ else:
270
+ for i, u in enumerate(unique_values):
271
+ if u == 0:
272
+ continue
273
+ result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[0] > 0.5] = u
274
+ else:
275
+ result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]
276
+
277
+ if input_was_numpy:
278
+ result = result.cpu().numpy()
279
+ else:
280
+ result = result.to(orig_device)
281
+
282
+ torch.set_num_threads(n_threads)
283
+ return result
284
+
285
+ def resample_torch_fornnunet(
286
+ data: Union[torch.Tensor, np.ndarray],
287
+ new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
288
+ current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
289
+ new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
290
+ is_seg: bool = False,
291
+ num_threads: int = 4,
292
+ device: torch.device = torch.device('cpu'),
293
+ memefficient_seg_resampling: bool = False,
294
+ force_separate_z: Union[bool, None] = None,
295
+ separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
296
+ mode: str = 'linear',
297
+ aniso_axis_mode: str = 'nearest-exact'
298
+ ) -> Union[torch.Tensor, np.ndarray]:
299
+ assert data.ndim == 4, "data must be c, x, y, z"
300
+ new_shape = [int(i) for i in new_shape]
301
+ orig_shape = data.shape
302
+
303
+ do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
304
+ separate_z_anisotropy_threshold)
305
+
306
+ if do_separate_z:
307
+ was_numpy = isinstance(data, np.ndarray)
308
+ if was_numpy:
309
+ data = torch.from_numpy(data)
310
+
311
+ if isinstance(axis, list):
312
+ axis = axis[0]
313
+
314
+ tmp = "xyz"
315
+ axis_letter = tmp[axis]
316
+ others_int = [i for i in range(3) if i != axis]
317
+ others = [tmp[i] for i in others_int]
318
+
319
+ data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
320
+ tmp_new_shape = [new_shape[i] for i in others_int]
321
+ data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
322
+ memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
323
+ data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
324
+ **{axis_letter: orig_shape[axis + 1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
325
+ data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
326
+ memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
327
+ if was_numpy:
328
+ data = data.numpy()
329
+ return data
330
+ else:
331
+ return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling)
332
+
333
+ def dice_score(pred: np.ndarray, true: np.ndarray) -> float:
334
+ """Compute Dice score for segmentation masks."""
335
+ pred = pred.flatten()
336
+ true = true.flatten()
337
+ intersection = np.sum(pred * true)
338
+ return (2. * intersection) / (np.sum(pred) + np.sum(true) + 1e-8)
339
+
340
+ # Placeholder for compute_new_shape if not provided
341
+ def compute_new_shape(original_shape, current_spacing, target_spacing):
342
+ """
343
+ Compute the new shape based on the spacing ratio.
344
+ original_shape: (z, y, x)
345
+ current_spacing: (z, y, x)
346
+ target_spacing: (z, y, x)
347
+ """
348
+ zoom_factors = [c / t for c, t in zip(current_spacing, target_spacing)]
349
+ new_shape = [int(round(s * z)) for s, z in zip(original_shape, zoom_factors)]
350
+ return tuple(new_shape)
351
+
352
+ # Function to save as NIfTI
353
+ def save_nii(array, spacing, output_path, is_seg=False):
354
+ """
355
+ Save numpy array as NIfTI file with specified spacing.
356
+ is_seg: If True, convert to int32 for segmentation masks.
357
+ """
358
+ # Convert torch tensor to numpy if necessary
359
+ if isinstance(array, torch.Tensor):
360
+ array = array.cpu().numpy()
361
+
362
+ # Convert data type for NIfTI compatibility
363
+ if is_seg:
364
+ array = array.astype(np.int32) # Convert segmentation to int32
365
+ else:
366
+ array = array.astype(np.float32) # Ensure image is float32
367
+
368
+ # Transpose to (X, Y, Z, C) for NIfTI
369
+ if array.ndim == 4:
370
+ array = array.transpose(2, 3, 1, 0) # From (C, Z, Y, X) to (X, Y, Z, C)
371
+ else:
372
+ array = array.transpose(2, 3, 1) # From (Z, Y, X) to (X, Y, Z)
373
+
374
+ # Create NIfTI image with affine based on spacing
375
+ affine = np.diag(list(spacing) + [1.0])
376
+ nii_img = nib.Nifti1Image(array, affine=affine)
377
+ nib.save(nii_img, output_path)
378
+ print(f"Saved: {output_path}")
379
+
380
+ # Main resampling function
381
+ def main():
382
+ torch.set_num_threads(4)
383
+ device = torch.device('cuda') #torch.device('cpu') # Force CPU as per provided code
384
+ print(f"\nRunning tests on device: {device}")
385
+
386
+ # Define paths
387
+ npz_file_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/inputs/Microscopy_cremi_000_sc.npz"
388
+ gt_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/gts/Microscopy_cremi_000_sc.npz"
389
+ output_dir = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/workspace_teamx/outputs_test_resample"
390
+
391
+ # Ensure output directory exists
392
+ if not os.path.exists(output_dir):
393
+ os.makedirs(output_dir)
394
+
395
+ # Load input data
396
+ data = np.load(npz_file_path, allow_pickle=True)
397
+ img_array = data['imgs'] # Shape: (C, Z, Y, X) or (Z, Y, X)
398
+ img_spacing = data['spacing'] # (z, y, x)
399
+ img_spacing = [1.0, 1.0, 1.0] # Override as per provided code
400
+ gt_data = np.load(gt_path, allow_pickle=True)
401
+ gt_array = gt_data['gts'] # Shape: (C, Z, Y, X) or (Z, Y, X)
402
+
403
+ # Convert data types to PyTorch-compatible types
404
+ img_array = img_array.astype(np.float32) # Convert image to float32
405
+ gt_array = gt_array.astype(np.int32) # Convert segmentation mask to int32
406
+
407
+ # Ensure img_array and gt_array have channel dimension
408
+ if img_array.ndim == 3:
409
+ img_array = img_array[np.newaxis, ...] # Add channel dimension: (1, Z, Y, X)
410
+ if gt_array.ndim == 3:
411
+ gt_array = gt_array[np.newaxis, ...] # Add channel dimension: (1, Z, Y, X)
412
+
413
+ # Define target spacings to test
414
+ target_spacings = [
415
+ (1.2, 1.2, 1.2),
416
+ (1.5, 1.5, 1.5),
417
+ (2.0, 2.0, 2.0),
418
+ ]
419
+
420
+ # Original shape and spacing
421
+ original_shape = img_array.shape[1:] # (Z, Y, X)
422
+ current_spacing = img_spacing
423
+ print(f"\nOriginal image shape: {original_shape}, Current spacing (z,y,x): {current_spacing}")
424
+
425
+ for target_spacing in target_spacings:
426
+ print(f"\n=== Resampling to Target Spacing: {target_spacing} ===")
427
+
428
+ # Compute new shape
429
+ new_shape = compute_new_shape(original_shape, current_spacing, target_spacing)
430
+ print(f"Computed target shape: {new_shape}")
431
+
432
+ # === Image Resampling ===
433
+ print("\nResampling image...")
434
+
435
+ # Ground truth resampling
436
+ print("Computing ground truth with resample_torch_simple...")
437
+ start_time = time.time()
438
+ if device.type == 'cuda':
439
+ torch.cuda.synchronize() # Ensure GPU operations are complete
440
+ gt_img = resample_torch_simple(
441
+ img_array,
442
+ new_shape=new_shape,
443
+ is_seg=False,
444
+ num_threads=4,
445
+ device=device
446
+ )
447
+ if device.type == 'cuda':
448
+ torch.cuda.synchronize() # Ensure GPU operations are complete
449
+ gt_time = time.time() - start_time
450
+ output_path = os.path.join(output_dir, f"img_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
451
+ print(f"Ground truth image shape: {gt_img.shape}, Time: {gt_time:.3f}s")
452
+ save_nii(gt_img, target_spacing, output_path, is_seg=False)
453
+
454
+ # Optimized resampling
455
+ print("Running optimized_3d_resample...")
456
+ start_time = time.time()
457
+ if device.type == 'cuda':
458
+ torch.cuda.synchronize()
459
+ mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
460
+ resampled_img_opt = optimized_3d_resample(
461
+ img_array,
462
+ current_spacing,
463
+ target_spacing,
464
+ is_seg=False,
465
+ device=device,
466
+ num_threads=4,
467
+ chunk_size=64
468
+ )
469
+ if device.type == 'cuda':
470
+ torch.cuda.synchronize()
471
+
472
+ opt_time = time.time() - start_time
473
+ mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
474
+ opt_mae = np.mean(np.abs(resampled_img_opt - gt_img))
475
+ output_path = os.path.join(output_dir, f"img_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
476
+ print(f"Optimized image shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, "
477
+ f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {opt_mae:.6f}")
478
+ save_nii(resampled_img_opt, target_spacing, output_path, is_seg=False)
479
+
480
+ # Original resampling
481
+ print("Running resample_torch_fornnunet...")
482
+ start_time = time.time()
483
+ if device.type == 'cuda':
484
+ torch.cuda.synchronize()
485
+ mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
486
+ resampled_img_orig = resample_torch_fornnunet(
487
+ img_array,
488
+ new_shape,
489
+ current_spacing,
490
+ target_spacing,
491
+ is_seg=False,
492
+ num_threads=4,
493
+ device=device
494
+ )
495
+ if device.type == 'cuda':
496
+ torch.cuda.synchronize()
497
+ orig_time = time.time() - start_time
498
+ mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
499
+ orig_mae = np.mean(np.abs(resampled_img_orig - gt_img))
500
+ output_path = os.path.join(output_dir, f"img_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
501
+ print(f"Original image shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, "
502
+ f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {orig_mae:.6f}")
503
+ save_nii(resampled_img_orig, target_spacing, output_path, is_seg=False)
504
+
505
+ # === Segmentation Mask Resampling ===
506
+ print("\nResampling segmentation mask...")
507
+
508
+ # Ground truth resampling
509
+ print("Computing ground truth with resample_torch_simple...")
510
+ start_time = time.time()
511
+ if device.type == 'cuda':
512
+ torch.cuda.synchronize()
513
+ gt_seg = resample_torch_simple(
514
+ gt_array,
515
+ new_shape=new_shape,
516
+ is_seg=True,
517
+ num_threads=4,
518
+ device=device
519
+ )
520
+ if device.type == 'cuda':
521
+ torch.cuda.synchronize()
522
+ gt_seg_time = time.time() - start_time
523
+ output_path = os.path.join(output_dir, f"seg_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
524
+ print(f"Ground truth segmentation shape: {gt_seg.shape}, Time: {gt_seg_time:.3f}s")
525
+ save_nii(gt_seg, target_spacing, output_path, is_seg=True)
526
+
527
+ # Optimized resampling
528
+ print("Running optimized_3d_resample for segmentation...")
529
+ start_time = time.time()
530
+ if device.type == 'cuda':
531
+ torch.cuda.synchronize()
532
+ mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
533
+ resampled_seg_opt = optimized_3d_resample(
534
+ gt_array,
535
+ current_spacing,
536
+ target_spacing,
537
+ is_seg=True,
538
+ device=device,
539
+ num_threads=4,
540
+ chunk_size=64
541
+ )
542
+ if device.type == 'cuda':
543
+ torch.cuda.synchronize()
544
+
545
+ opt_seg_time = time.time() - start_time
546
+ mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
547
+ opt_dice = dice_score(resampled_seg_opt, gt_seg)
548
+ output_path = os.path.join(output_dir, f"seg_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
549
+ print(f"Optimized segmentation shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, "
550
+ f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {opt_dice:.6f}")
551
+ save_nii(resampled_seg_opt, target_spacing, output_path, is_seg=True)
552
+
553
+ # Original resampling
554
+ print("Running resample_torch_fornnunet for segmentation...")
555
+ start_time = time.time()
556
+ if device.type == 'cuda':
557
+ torch.cuda.synchronize()
558
+ mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
559
+ resampled_seg_orig = resample_torch_fornnunet(
560
+ gt_array,
561
+ new_shape,
562
+ current_spacing,
563
+ target_spacing,
564
+ is_seg=True,
565
+ num_threads=4,
566
+ device=device
567
+ )
568
+ if device.type == 'cuda':
569
+ torch.cuda.synchronize()
570
+
571
+ orig_seg_time = time.time() - start_time
572
+ mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
573
+ orig_dice = dice_score(resampled_seg_orig, gt_seg)
574
+ output_path = os.path.join(output_dir, f"seg_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
575
+ print(f"Original segmentation shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, "
576
+ f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {orig_dice:.6f}")
577
+ save_nii(resampled_seg_orig, target_spacing, output_path, is_seg=True)
578
+
579
+ # Summary
580
+ print(f"\n=== Summary for Target Spacing: {target_spacing} ===")
581
+ print("Image Resampling Metrics:")
582
+ print(f"Optimized - Shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, MAE: {opt_mae:.6f}")
583
+ print(f"Original - Shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, MAE: {orig_mae:.6f}")
584
+ print(f"Time Improvement: {(orig_time - opt_time) / orig_time * 100:.2f}%")
585
+ print(f"MAE Improvement: {(orig_mae - opt_mae) / orig_mae * 100:.2f}%")
586
+ print("Segmentation Mask Resampling Metrics:")
587
+ print(f"Optimized - Shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, Dice: {opt_dice:.6f}")
588
+ print(f"Original - Shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, Dice: {orig_dice:.6f}")
589
+ print(f"Time Improvement: {(orig_seg_time - opt_seg_time) / orig_seg_time * 100:.2f}%")
590
+ print(f"Dice Improvement: {(opt_dice - orig_dice) / orig_dice * 100:.2f}%")
591
+
592
+ if __name__ == '__main__':
593
+ main()
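For reference, the `dice_score` helper defined above is the standard binary Dice overlap; a tiny worked example with made-up masks:

```python
import numpy as np

pred = np.array([1, 1, 1, 0, 0])
true = np.array([0, 1, 1, 1, 0])
# |A| = 3, |B| = 3, |A ∩ B| = 2  ->  Dice = 2*2 / (3 + 3) ≈ 0.667
dice = 2.0 * np.sum(pred * true) / (np.sum(pred) + np.sum(true) + 1e-8)
print(round(float(dice), 3))  # 0.667
```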
environment.yml ADDED
@@ -0,0 +1,211 @@
1
+ name: medals_local_test
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - aom=3.12.1=h7934f7d_0
10
+ - blas=1.0=mkl
11
+ - brotlicffi=1.2.0.0=py310h7354ed3_0
12
+ - bzip2=1.0.8=h5eee18b_6
13
+ - ca-certificates=2025.12.2=h06a4308_0
14
+ - cairo=1.18.4=h44eff21_0
15
+ - certifi=2025.11.12=py310h06a4308_0
16
+ - cffi=2.0.0=py310h4eded50_1
17
+ - charset-normalizer=3.4.4=py310h06a4308_0
18
+ - cuda-cudart=12.1.105=0
19
+ - cuda-cupti=12.1.105=0
20
+ - cuda-libraries=12.1.0=0
21
+ - cuda-nvrtc=12.1.105=0
22
+ - cuda-nvtx=12.1.105=0
23
+ - cuda-opencl=12.9.19=0
24
+ - cuda-runtime=12.1.0=0
25
+ - cuda-version=12.9=3
26
+ - dav1d=1.2.1=h5eee18b_0
27
+ - expat=2.7.3=h7354ed3_4
28
+ - ffmpeg=6.1.1=hecf7045_5
29
+ - filelock=3.20.0=py310h06a4308_0
30
+ - fontconfig=2.15.0=h2c49b7f_0
31
+ - freetype=2.13.3=h4a9f257_0
32
+ - fribidi=1.0.10=h7b6447c_0
33
+ - giflib=5.2.2=h5eee18b_0
34
+ - gmp=6.3.0=h6a678d5_0
35
+ - gmpy2=2.2.2=py310ha78e65c_0
36
+ - graphite2=1.3.14=h295c915_1
37
+ - harfbuzz=10.2.0=hdfddeaa_1
38
+ - icu=73.1=h6a678d5_0
39
+ - idna=3.11=py310h06a4308_0
40
+ - intel-openmp=2025.0.0=h06a4308_1171
41
+ - jinja2=3.1.6=py310h06a4308_0
42
+ - jpeg=9f=h5ce9db8_0
43
+ - lame=3.100=h7b6447c_0
44
+ - lcms2=2.17=heab6991_0
45
+ - ld_impl_linux-64=2.44=h153f514_2
46
+ - leptonica=1.82.0=hfdeec58_3
47
+ - lerc=4.0.0=h6a678d5_0
48
+ - libarchive=3.8.2=h3ec8f01_0
49
+ - libavif=1.3.0=h3539ee5_0
50
+ - libcublas=12.1.0.26=0
51
+ - libcufft=11.0.2.4=0
52
+ - libcufile=1.14.1.1=4
53
+ - libcurand=10.3.10.19=0
54
+ - libcusolver=11.4.4.55=0
55
+ - libcusparse=12.0.2.55=0
56
+ - libdeflate=1.22=h5eee18b_0
57
+ - libexpat=2.7.3=h7354ed3_4
58
+ - libffi=3.4.4=h6a678d5_1
59
+ - libgcc=15.2.0=h69a1729_7
60
+ - libgcc-ng=15.2.0=h166f726_7
61
+ - libglib=2.84.4=h77a78f3_0
62
+ - libgomp=15.2.0=h4751f2c_7
63
+ - libhwloc=2.12.1=default_hf1bbc79_1000
64
+ - libiconv=1.16=h5eee18b_3
65
+ - libjpeg-turbo=2.0.0=h9bf148f_0
66
+ - libnpp=12.0.2.50=0
67
+ - libnsl=2.0.0=h5eee18b_0
68
+ - libnvjitlink=12.1.105=0
69
+ - libnvjpeg=12.1.1.14=0
70
+ - libogg=1.3.5=h27cfd23_1
71
+ - libopenjpeg=2.5.4=hee96239_1
72
+ - libopus=1.3.1=h5eee18b_1
73
+ - libpng=1.6.50=h2ed474d_0
74
+ - libstdcxx=15.2.0=h39759b7_7
75
+ - libstdcxx-ng=15.2.0=hc03a8fd_7
76
+ - libtheora=1.2.0=h32ad74f_1
77
+ - libtiff=4.7.1=h029b1ac_0
78
+ - libuuid=1.41.5=h5eee18b_0
79
+ - libvorbis=1.3.7=h7b6447c_0
80
+ - libvpx=1.15.2=h4cb591d_0
81
+ - libwebp=1.6.0=h089d785_0
82
+ - libwebp-base=1.6.0=hb7bb969_0
83
+ - libxcb=1.17.0=h9b100fa_0
84
+ - libxml2=2.13.9=h2c43086_0
85
+ - libzlib=1.3.1=hb25bd0a_0
86
+ - llvm-openmp=14.0.6=h9e868ea_0
87
+ - lz4-c=1.9.4=h6a678d5_1
88
+ - markupsafe=3.0.2=py310h5eee18b_0
89
+ - mkl=2025.0.0=hacee8c2_941
90
+ - mkl-service=2.5.2=py310hacdc0fc_0
91
+ - mkl_fft=2.1.1=py310h8fe796d_0
92
+ - mkl_random=1.3.0=py310h505adc9_0
93
+ - mpc=1.3.1=h5eee18b_0
94
+ - mpfr=4.2.1=h5eee18b_0
95
+ - mpmath=1.3.0=py310h06a4308_0
96
+ - ncurses=6.5=h7934f7d_0
97
+ - networkx=3.4.2=py310h06a4308_0
98
+ - ocl-icd=2.3.3=h47b2149_0
99
+ - opencl-headers=2025.07.22=hfb20e49_0
100
+ - openh264=2.6.0=he621ea3_0
101
+ - openjpeg=2.5.4=h4e0627c_1
102
+ - openssl=3.0.18=hd6dcaed_0
103
+ - pcre2=10.46=hf426167_0
104
+ - pillow=12.0.0=py310h3b88751_1
105
+ - pip=25.3=pyhc872135_0
106
+ - pixman=0.46.4=h7934f7d_0
107
+ - pthread-stubs=0.3=h0ce48e5_1
108
+ - pycparser=2.23=py310h06a4308_0
109
+ - pysocks=1.7.1=py310h06a4308_1
110
+ - python=3.10.19=h6fa692b_0
111
+ - pytorch-cuda=12.1=ha16c6d3_6
112
+ - pytorch-mutex=1.0=cuda
113
+ - pyyaml=6.0.3=py310h591646f_0
114
+ - readline=8.3=hc2a1206_0
115
+ - requests=2.32.5=py310h06a4308_1
116
+ - setuptools=80.9.0=py310h06a4308_0
117
+ - sqlite=3.51.0=h2a70700_0
118
+ - sympy=1.14.0=py310h06a4308_1
119
+ - tbb=2022.3.0=h698db13_0
120
+ - tbb-devel=2022.3.0=h698db13_0
121
+ - tesseract=5.2.0=hb0d2e87_3
122
+ - tk=8.6.15=h54e0aa7_0
123
+ - typing_extensions=4.15.0=py310h06a4308_0
124
+ - urllib3=2.6.1=py310h06a4308_0
125
+ - wheel=0.45.1=py310h06a4308_0
126
+ - xorg-libx11=1.8.12=h9b100fa_1
127
+ - xorg-libxau=1.0.12=h9b100fa_0
128
+ - xorg-libxdmcp=1.1.5=h9b100fa_0
129
+ - xorg-libxext=1.3.6=h9b100fa_0
130
+ - xorg-libxrender=0.9.12=h9b100fa_0
131
+ - xorg-xorgproto=2024.1=h5eee18b_1
132
+ - xz=5.6.4=h5eee18b_1
133
+ - yaml=0.2.5=h7b6447c_0
134
+ - zlib=1.3.1=hb25bd0a_0
135
+ - zstd=1.5.7=h11fc155_0
136
+ - pip:
137
+ - acvl-utils==0.2.5
138
+ - argparse==1.4.0
139
+ - batchgenerators==0.25.1
140
+ - blosc2==3.12.2
141
+ - connected-components-3d==3.26.1
142
+ - contourpy==1.3.2
143
+ - cycler==0.12.1
144
+ - dicom2nifti==2.6.2
145
+ - dynamic-network-architectures==0.2
146
+ - einops==0.8.1
147
+ - fonttools==4.61.1
148
+ - fsspec==2025.12.0
149
+ - future==1.0.0
150
+ - hf-xet==1.2.0
151
+ - huggingface-hub==0.36.0
152
+ - imagecodecs==2025.3.30
153
+ - imageio==2.37.2
154
+ - importlib-resources==6.5.2
155
+ - joblib==1.5.3
156
+ - kiwisolver==1.4.9
157
+ - lazy-loader==0.4
158
+ - linecache2==1.0.0
159
+ - matplotlib==3.10.8
160
+ - monai==1.4.0
161
+ - msgpack==1.1.2
162
+ - ndindex==1.10.1
163
+ - nibabel==5.3.2
164
+ - nnunetv2==2.4.1
165
+ - numexpr==2.14.1
166
+ - numpy==1.26.4
167
+ - nvidia-cublas-cu12==12.1.3.1
168
+ - nvidia-cuda-cupti-cu12==12.1.105
169
+ - nvidia-cuda-nvrtc-cu12==12.1.105
170
+ - nvidia-cuda-runtime-cu12==12.1.105
171
+ - nvidia-cudnn-cu12==8.9.2.26
172
+ - nvidia-cufft-cu12==11.0.2.54
173
+ - nvidia-curand-cu12==10.3.2.106
174
+ - nvidia-cusolver-cu12==11.4.5.107
175
+ - nvidia-cusparse-cu12==12.1.0.106
176
+ - nvidia-nccl-cu12==2.19.3
177
+ - nvidia-nvjitlink-cu12==12.9.86
178
+ - nvidia-nvtx-cu12==12.1.105
179
+ - packaging==25.0
180
+ - pandas==2.3.3
181
+ - platformdirs==4.5.1
182
+ - positional-encodings==6.0.3
183
+ - py-cpuinfo==9.0.0
184
+ - pydicom==3.0.1
185
+ - pyparsing==3.2.5
186
+ - python-dateutil==2.9.0.post0
187
+ - python-gdcm==3.2.2
188
+ - python-graphviz==0.21
189
+ - pytz==2025.2
190
+ - regex==2025.11.3
191
+ - safetensors==0.7.0
192
+ - scikit-image==0.25.2
193
+ - scikit-learn==1.7.2
194
+ - scipy==1.15.3
195
+ - seaborn==0.13.2
196
+ - simpleitk==2.5.3
197
+ - six==1.17.0
198
+ - threadpoolctl==3.6.0
199
+ - tifffile==2025.5.10
200
+ - tokenizers==0.21.4
201
+ - torch==2.2.0+cu121
202
+ - torchaudio==2.2.0+cu121
203
+ - torchvision==0.17.0+cu121
204
+ - tqdm==4.67.1
205
+ - traceback2==1.4.0
206
+ - transformers==4.51.3
207
+ - triton==2.2.0
208
+ - tzdata==2025.3
209
+ - unittest2==1.1.0
210
+ - yacs==0.1.8
211
+ prefix: /yinghepool/shipengcheng/.conda/envs/medals_local_test
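
For reference, a minimal sanity check of the pinned builds after the environment is created (this snippet is illustrative and not part of the uploaded files; the expected values are taken from the pins above):

    import torch
    import numpy as np
    import scipy

    print(torch.__version__)              # expected: 2.2.0+cu121
    print(torch.version.cuda)             # expected: 12.1
    print(torch.cuda.is_available())      # True on a machine with a matching driver
    print(np.__version__)                 # expected: 1.26.4
    print(scipy.__version__)              # expected: 1.15.3
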
evaluate/SurfaceDice.py ADDED
@@ -0,0 +1,492 @@
1
+ import numpy as np
2
+ import scipy.ndimage
3
+
4
+ # neighbour_code_to_normals is a lookup table.
5
+ # For every binary neighbour code
6
+ # (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)
7
+ # it contains the surface normals of the triangles (called "surfel" for
8
+ # "surface element" in the following). The length of the normal
9
+ # vector encodes the surfel area.
10
+ #
11
+ # created by compute_surface_area_lookup_table.ipynb using the
12
+ # marching cubes algorithm, see e.g. https://en.wikipedia.org/wiki/Marching_cubes
13
+ # credit to: http://medicaldecathlon.com/files/Surface_distance_based_measures.ipynb
14
+ neighbour_code_to_normals = [
15
+ [[0,0,0]],
16
+ [[0.125,0.125,0.125]],
17
+ [[-0.125,-0.125,0.125]],
18
+ [[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
19
+ [[0.125,-0.125,0.125]],
20
+ [[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
21
+ [[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
22
+ [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
23
+ [[-0.125,0.125,0.125]],
24
+ [[0.125,0.125,0.125],[-0.125,0.125,0.125]],
25
+ [[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
26
+ [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
27
+ [[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
28
+ [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]],
29
+ [[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
30
+ [[0.5,0.0,0.0],[0.5,0.0,0.0]],
31
+ [[0.125,-0.125,-0.125]],
32
+ [[0.0,-0.25,-0.25],[0.0,0.25,0.25]],
33
+ [[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
34
+ [[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
35
+ [[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
36
+ [[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]],
37
+ [[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
38
+ [[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]],
39
+ [[-0.125,0.125,0.125],[0.125,-0.125,-0.125]],
40
+ [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]],
41
+ [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]],
42
+ [[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]],
43
+ [[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
44
+ [[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
45
+ [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]],
46
+ [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]],
47
+ [[0.125,-0.125,0.125]],
48
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125]],
49
+ [[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
50
+ [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]],
51
+ [[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
52
+ [[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
53
+ [[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]],
54
+ [[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]],
55
+ [[-0.125,0.125,0.125],[0.125,-0.125,0.125]],
56
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]],
57
+ [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
58
+ [[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]],
59
+ [[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
60
+ [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]],
61
+ [[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]],
62
+ [[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
63
+ [[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
64
+ [[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
65
+ [[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]],
66
+ [[0.0,0.5,0.0],[0.0,-0.5,0.0]],
67
+ [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]],
68
+ [[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
69
+ [[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
70
+ [[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
71
+ [[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
72
+ [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
73
+ [[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
74
+ [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
75
+ [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
76
+ [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]],
77
+ [[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
78
+ [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
79
+ [[-0.125,-0.125,0.125]],
80
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
81
+ [[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
82
+ [[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
83
+ [[0.0,-0.25,0.25],[0.0,-0.25,0.25]],
84
+ [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
85
+ [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]],
86
+ [[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]],
87
+ [[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
88
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
89
+ [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
90
+ [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
91
+ [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
92
+ [[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]],
93
+ [[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]],
94
+ [[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
95
+ [[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
96
+ [[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
97
+ [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
98
+ [[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]],
99
+ [[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
100
+ [[-0.0,0.0,0.5],[0.0,0.0,0.5]],
101
+ [[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
102
+ [[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
103
+ [[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]],
104
+ [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
105
+ [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
106
+ [[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
107
+ [[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]],
108
+ [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
109
+ [[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
110
+ [[0.25,0.0,0.25],[0.25,0.0,0.25]],
111
+ [[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
112
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
113
+ [[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
114
+ [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
115
+ [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]],
116
+ [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
117
+ [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
118
+ [[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]],
119
+ [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
120
+ [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
121
+ [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
122
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
123
+ [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
124
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
125
+ [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]],
126
+ [[0.125,0.125,0.125],[0.125,-0.125,-0.125]],
127
+ [[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]],
128
+ [[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
129
+ [[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
130
+ [[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
131
+ [[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]],
132
+ [[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]],
133
+ [[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]],
134
+ [[0.0,0.25,0.25],[0.0,0.25,0.25]],
135
+ [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]],
136
+ [[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
137
+ [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]],
138
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125]],
139
+ [[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]],
140
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
141
+ [[0.125,0.125,0.125],[0.125,0.125,0.125]],
142
+ [[0.125,0.125,0.125]],
143
+ [[0.125,0.125,0.125]],
144
+ [[0.125,0.125,0.125],[0.125,0.125,0.125]],
145
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
146
+ [[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]],
147
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125]],
148
+ [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]],
149
+ [[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
150
+ [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]],
151
+ [[0.0,0.25,0.25],[0.0,0.25,0.25]],
152
+ [[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]],
153
+ [[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]],
154
+ [[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]],
155
+ [[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
156
+ [[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
157
+ [[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
158
+ [[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]],
159
+ [[0.125,0.125,0.125],[0.125,-0.125,-0.125]],
160
+ [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]],
161
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
162
+ [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
163
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
164
+ [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
165
+ [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]],
166
+ [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
167
+ [[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]],
168
+ [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.0,0.25,0.25],[0.0,0.25,0.25]],
169
+ [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
170
+ [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]],
171
+ [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
172
+ [[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
173
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
174
+ [[-0.125,-0.125,0.125],[0.125,-0.125,0.125]],
175
+ [[0.25,0.0,0.25],[0.25,0.0,0.25]],
176
+ [[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
177
+ [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
178
+ [[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]],
179
+ [[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]],
180
+ [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.25,0.0,0.25],[0.25,0.0,0.25]],
181
+ [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
182
+ [[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]],
183
+ [[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
184
+ [[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]],
185
+ [[-0.0,0.0,0.5],[0.0,0.0,0.5]],
186
+ [[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
187
+ [[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]],
188
+ [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
189
+ [[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
190
+ [[-0.25,0.0,0.25],[0.25,0.0,-0.25]],
191
+ [[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
192
+ [[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]],
193
+ [[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]],
194
+ [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]],
195
+ [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
196
+ [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
197
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
198
+ [[-0.125,-0.125,0.125],[-0.125,0.125,0.125]],
199
+ [[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]],
200
+ [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]],
201
+ [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]],
202
+ [[0.0,-0.25,0.25],[0.0,-0.25,0.25]],
203
+ [[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
204
+ [[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
205
+ [[0.125,0.125,0.125],[-0.125,-0.125,0.125]],
206
+ [[-0.125,-0.125,0.125]],
207
+ [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
208
+ [[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]],
209
+ [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]],
210
+ [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
211
+ [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]],
212
+ [[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]],
213
+ [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
214
+ [[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
215
+ [[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
216
+ [[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]],
217
+ [[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
218
+ [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]],
219
+ [[0.0,0.5,0.0],[0.0,-0.5,0.0]],
220
+ [[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]],
221
+ [[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]],
222
+ [[0.25,-0.25,0.0],[-0.25,0.25,0.0]],
223
+ [[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
224
+ [[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]],
225
+ [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]],
226
+ [[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
227
+ [[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]],
228
+ [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
229
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]],
230
+ [[-0.125,0.125,0.125],[0.125,-0.125,0.125]],
231
+ [[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]],
232
+ [[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]],
233
+ [[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
234
+ [[0.125,-0.125,0.125],[0.125,-0.125,0.125]],
235
+ [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]],
236
+ [[0.0,-0.25,0.25],[0.0,0.25,-0.25]],
237
+ [[0.125,0.125,0.125],[0.125,-0.125,0.125]],
238
+ [[0.125,-0.125,0.125]],
239
+ [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]],
240
+ [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]],
241
+ [[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]],
242
+ [[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
243
+ [[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]],
244
+ [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]],
245
+ [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]],
246
+ [[-0.125,0.125,0.125],[0.125,-0.125,-0.125]],
247
+ [[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]],
248
+ [[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
249
+ [[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]],
250
+ [[0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
251
+ [[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
252
+ [[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]],
253
+ [[0.0,-0.25,-0.25],[0.0,0.25,0.25]],
254
+ [[0.125,-0.125,-0.125]],
255
+ [[0.5,0.0,0.0],[0.5,0.0,0.0]],
256
+ [[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]],
257
+ [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]],
258
+ [[0.25,-0.25,0.0],[0.25,-0.25,0.0]],
259
+ [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]],
260
+ [[-0.25,0.0,0.25],[-0.25,0.0,0.25]],
261
+ [[0.125,0.125,0.125],[-0.125,0.125,0.125]],
262
+ [[-0.125,0.125,0.125]],
263
+ [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]],
264
+ [[0.125,-0.125,0.125],[-0.125,-0.125,0.125]],
265
+ [[-0.25,-0.0,-0.25],[0.25,0.0,0.25]],
266
+ [[0.125,-0.125,0.125]],
267
+ [[-0.25,-0.25,0.0],[0.25,0.25,-0.0]],
268
+ [[-0.125,-0.125,0.125]],
269
+ [[0.125,0.125,0.125]],
270
+ [[0,0,0]]]
271
+
272
+
273
+ def compute_surface_distances(mask_gt, mask_pred, spacing_mm):
274
+ """Compute closest distances from all surface points to the other surface.
275
+
276
+ Finds all surface elements "surfels" in the ground truth mask `mask_gt` and
277
+ the predicted mask `mask_pred`, computes their area in mm^2 and the distance
278
+ to the closest point on the other surface. It returns two sorted lists of
279
+ distances together with the corresponding surfel areas. If one of the masks
280
+ is empty, the corresponding lists are empty and all distances in the other
281
+ list are `inf`
282
+
283
+ Args:
284
+ mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
285
+ mask_pred: 3-dim Numpy array of type bool. The predicted mask.
286
+ spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2
287
+ direction
288
+
289
+ Returns:
290
+ A dict with
291
+ "distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm
292
+ from all ground truth surface elements to the predicted surface,
293
+ sorted from smallest to largest
294
+ "distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm
295
+ from all predicted surface elements to the ground truth surface,
296
+ sorted from smallest to largest
297
+ "surfel_areas_gt": 1-dim numpy array of type float. The area in mm^2 of
298
+ the ground truth surface elements in the same order as
299
+ distances_gt_to_pred
300
+ "surfel_areas_pred": 1-dim numpy array of type float. The area in mm^2 of
301
+ the predicted surface elements in the same order as
302
+ distances_pred_to_gt
303
+
304
+ """
305
+
306
+ # compute the area for all 256 possible surface elements
307
+ # (given a 2x2x2 neighbourhood) according to the spacing_mm
308
+ neighbour_code_to_surface_area = np.zeros([256])
309
+ for code in range(256):
310
+ normals = np.array(neighbour_code_to_normals[code])
311
+ sum_area = 0
312
+ for normal_idx in range(normals.shape[0]):
313
+ # normal vector
314
+ n = np.zeros([3])
315
+ n[0] = normals[normal_idx,0] * spacing_mm[1] * spacing_mm[2]
316
+ n[1] = normals[normal_idx,1] * spacing_mm[0] * spacing_mm[2]
317
+ n[2] = normals[normal_idx,2] * spacing_mm[0] * spacing_mm[1]
318
+ area = np.linalg.norm(n)
319
+ sum_area += area
320
+ neighbour_code_to_surface_area[code] = sum_area
321
+
322
+ # compute the bounding box of the masks to trim
323
+ # the volume to the smallest possible processing subvolume
324
+ mask_all = mask_gt | mask_pred
325
+ bbox_min = np.zeros(3, np.int64)
326
+ bbox_max = np.zeros(3, np.int64)
327
+
328
+ # max projection to the x0-axis
329
+ proj_0 = np.max(np.max(mask_all, axis=2), axis=1)
330
+ idx_nonzero_0 = np.nonzero(proj_0)[0]
331
+ if len(idx_nonzero_0) == 0:
332
+ return {"distances_gt_to_pred": np.array([]),
333
+ "distances_pred_to_gt": np.array([]),
334
+ "surfel_areas_gt": np.array([]),
335
+ "surfel_areas_pred": np.array([])}
336
+
337
+ bbox_min[0] = np.min(idx_nonzero_0)
338
+ bbox_max[0] = np.max(idx_nonzero_0)
339
+
340
+ # max projection to the x1-axis
341
+ proj_1 = np.max(np.max(mask_all, axis=2), axis=0)
342
+ idx_nonzero_1 = np.nonzero(proj_1)[0]
343
+ bbox_min[1] = np.min(idx_nonzero_1)
344
+ bbox_max[1] = np.max(idx_nonzero_1)
345
+
346
+ # max projection to the x2-axis
347
+ proj_2 = np.max(np.max(mask_all, axis=1), axis=0)
348
+ idx_nonzero_2 = np.nonzero(proj_2)[0]
349
+ bbox_min[2] = np.min(idx_nonzero_2)
350
+ bbox_max[2] = np.max(idx_nonzero_2)
351
+
352
+ # print("bounding box min = {}".format(bbox_min))
353
+ # print("bounding box max = {}".format(bbox_max))
354
+
355
+ # crop the processing subvolume.
356
+ # we need to zeropad the cropped region with 1 voxel at the lower,
357
+ # the right and the back side. This is required to obtain the "full"
358
+ # convolution result with the 2x2x2 kernel
359
+ cropmask_gt = np.zeros((bbox_max - bbox_min)+2, np.uint8)
360
+ cropmask_pred = np.zeros((bbox_max - bbox_min)+2, np.uint8)
361
+
362
+ cropmask_gt[0:-1, 0:-1, 0:-1] = mask_gt[bbox_min[0]:bbox_max[0]+1,
363
+ bbox_min[1]:bbox_max[1]+1,
364
+ bbox_min[2]:bbox_max[2]+1]
365
+
366
+ cropmask_pred[0:-1, 0:-1, 0:-1] = mask_pred[bbox_min[0]:bbox_max[0]+1,
367
+ bbox_min[1]:bbox_max[1]+1,
368
+ bbox_min[2]:bbox_max[2]+1]
369
+
370
+ # compute the neighbour code (local binary pattern) for each voxel
371
+ # the resulting arrays are spatially shifted by minus half a voxel in each axis.
372
+ # i.e. the points are located at the corners of the original voxels
373
+ kernel = np.array([[[128,64],
374
+ [32,16]],
375
+ [[8,4],
376
+ [2,1]]])
377
+ neighbour_code_map_gt = scipy.ndimage.correlate(cropmask_gt.astype(np.uint8), kernel, mode="constant", cval=0)
378
+ neighbour_code_map_pred = scipy.ndimage.correlate(cropmask_pred.astype(np.uint8), kernel, mode="constant", cval=0)
379
+
380
+ # create masks with the surface voxels
381
+ borders_gt = ((neighbour_code_map_gt != 0) & (neighbour_code_map_gt != 255))
382
+ borders_pred = ((neighbour_code_map_pred != 0) & (neighbour_code_map_pred != 255))
383
+
384
+ # compute the distance transform (closest distance of each voxel to the surface voxels)
385
+ if borders_gt.any():
386
+ distmap_gt = scipy.ndimage.distance_transform_edt(~borders_gt, sampling=spacing_mm)
387
+ else:
388
+ distmap_gt = np.inf * np.ones(borders_gt.shape)
389
+
390
+ if borders_pred.any():
391
+ distmap_pred = scipy.ndimage.distance_transform_edt(~borders_pred, sampling=spacing_mm)
392
+ else:
393
+ distmap_pred = np.inf * np.ones(borders_pred.shape)
394
+
395
+ # compute the area of each surface element
396
+ surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]
397
+ surface_area_map_pred = neighbour_code_to_surface_area[neighbour_code_map_pred]
398
+
399
+ # create a list of all surface elements with distance and area
400
+ distances_gt_to_pred = distmap_pred[borders_gt]
401
+ distances_pred_to_gt = distmap_gt[borders_pred]
402
+ surfel_areas_gt = surface_area_map_gt[borders_gt]
403
+ surfel_areas_pred = surface_area_map_pred[borders_pred]
404
+
405
+ # sort them by distance
406
+ if distances_gt_to_pred.shape != (0,):
407
+ sorted_surfels_gt = np.array(sorted(zip(distances_gt_to_pred, surfel_areas_gt)))
408
+ distances_gt_to_pred = sorted_surfels_gt[:,0]
409
+ surfel_areas_gt = sorted_surfels_gt[:,1]
410
+
411
+ if distances_pred_to_gt.shape != (0,):
412
+ sorted_surfels_pred = np.array(sorted(zip(distances_pred_to_gt, surfel_areas_pred)))
413
+ distances_pred_to_gt = sorted_surfels_pred[:,0]
414
+ surfel_areas_pred = sorted_surfels_pred[:,1]
415
+
416
+
417
+ return {"distances_gt_to_pred": distances_gt_to_pred,
418
+ "distances_pred_to_gt": distances_pred_to_gt,
419
+ "surfel_areas_gt": surfel_areas_gt,
420
+ "surfel_areas_pred": surfel_areas_pred}
421
+
422
+
423
+ def compute_average_surface_distance(surface_distances):
424
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
425
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
426
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
427
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
428
+ average_distance_gt_to_pred = np.sum( distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt)
429
+ average_distance_pred_to_gt = np.sum( distances_pred_to_gt * surfel_areas_pred) / np.sum(surfel_areas_pred)
430
+ return (average_distance_gt_to_pred, average_distance_pred_to_gt)
431
+
432
+ def compute_robust_hausdorff(surface_distances, percent):
433
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
434
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
435
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
436
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
437
+ if len(distances_gt_to_pred) > 0:
438
+ surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt)
439
+ idx = np.searchsorted(surfel_areas_cum_gt, percent/100.0)
440
+ perc_distance_gt_to_pred = distances_gt_to_pred[min(idx, len(distances_gt_to_pred)-1)]
441
+ else:
442
+ perc_distance_gt_to_pred = np.inf
443
+
444
+ if len(distances_pred_to_gt) > 0:
445
+ surfel_areas_cum_pred = np.cumsum(surfel_areas_pred) / np.sum(surfel_areas_pred)
446
+ idx = np.searchsorted(surfel_areas_cum_pred, percent/100.0)
447
+ perc_distance_pred_to_gt = distances_pred_to_gt[min(idx, len(distances_pred_to_gt)-1)]
448
+ else:
449
+ perc_distance_pred_to_gt = np.inf
450
+
451
+ return max( perc_distance_gt_to_pred, perc_distance_pred_to_gt)
452
+
453
+ def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
454
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
455
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
456
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
457
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
458
+ rel_overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) / np.sum(surfel_areas_gt)
459
+ rel_overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) / np.sum(surfel_areas_pred)
460
+ return (rel_overlap_gt, rel_overlap_pred)
461
+
462
+ def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
463
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
464
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
465
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
466
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
467
+ overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm])
468
+ overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm])
469
+ surface_dice = (overlap_gt + overlap_pred) / (
470
+ np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred))
471
+ return surface_dice
472
+
473
+
474
+ def compute_dice_coefficient(mask_gt, mask_pred):
475
+ """Compute soerensen-dice coefficient.
476
+
477
+ compute the soerensen-dice coefficient between the ground truth mask `mask_gt`
478
+ and the predicted mask `mask_pred`.
479
+
480
+ Args:
481
+ mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
482
+ mask_pred: 3-dim Numpy array of type bool. The predicted mask.
483
+
484
+ Returns:
485
+ the dice coefficient as float. If both masks are empty, the result is NaN.
486
+ """
487
+ volume_sum = mask_gt.sum() + mask_pred.sum()
488
+ if volume_sum == 0:
489
+ return np.nan
490
+ volume_intersect = (mask_gt & mask_pred).sum()
491
+ return 2*volume_intersect / volume_sum
492
+
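
The helpers above follow the Medical Decathlon surface-distance recipe credited in the header: compute_surface_distances builds the sorted surfel distance/area lists once, and the other functions reduce them to NSD, robust Hausdorff, or volumetric Dice. A minimal usage sketch (toy masks, assuming the repository root is on PYTHONPATH):

    import numpy as np
    from evaluate.SurfaceDice import (compute_surface_distances,
                                      compute_surface_dice_at_tolerance,
                                      compute_robust_hausdorff,
                                      compute_dice_coefficient)

    # toy 3D masks, made up purely for illustration
    gt = np.zeros((32, 32, 16), dtype=bool)
    pred = np.zeros((32, 32, 16), dtype=bool)
    gt[8:20, 8:20, 4:10] = True
    pred[9:21, 8:20, 4:10] = True

    sd = compute_surface_distances(gt, pred, spacing_mm=(1.0, 1.0, 3.0))
    print(compute_surface_dice_at_tolerance(sd, tolerance_mm=1.0))  # NSD at 1 mm
    print(compute_robust_hausdorff(sd, percent=95))                 # robust (95%) Hausdorff in mm
    print(compute_dice_coefficient(gt, pred))                       # volumetric Dice
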
evaluate/__init__.py ADDED
File without changes
evaluate/evaluator.py ADDED
@@ -0,0 +1,379 @@
1
+ import os
2
+ import time
3
+
4
+ import torch
5
+ from torch.cuda.amp import autocast as autocast
6
+ from tqdm import tqdm
7
+ from einops import rearrange, repeat, reduce
8
+ import numpy as np
9
+ import pandas as pd
10
+ from pathlib import Path
11
+ import nibabel as nib
12
+ import shutil
13
+ import pickle
14
+ from scipy.ndimage import gaussian_filter
15
+ import torch.distributed as dist
16
+
17
+ from evaluate.metric import calculate_metric_percase
18
+ from evaluate.merge_after_evaluate import merge
19
+ from train.dist import is_master
20
+
21
+ def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_factor: float = 10, dtype=np.float16):
22
+ tmp = np.zeros(tile_size)
23
+ center_coords = [i // 2 for i in tile_size]
24
+ sigmas = [i * sigma_scale for i in tile_size]
25
+ tmp[tuple(center_coords)] = 1
26
+ gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
27
+
28
+ # gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
29
+
30
+ gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * value_scaling_factor
31
+ gaussian_importance_map = gaussian_importance_map.astype(dtype)
32
+
33
+ # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
34
+ gaussian_importance_map[gaussian_importance_map == 0] = np.min(
35
+ gaussian_importance_map[gaussian_importance_map != 0])
36
+
37
+ return gaussian_importance_map
38
+
39
+ def evaluate(model,
40
+ text_encoder,
41
+ device,
42
+ testset,
43
+ testloader,
44
+ dice_score,
45
+ nsd_score,
46
+ csv_path,
47
+ resume,
48
+ save_interval,
49
+ visualization):
50
+
51
+ # whether to store pred, gt and img (as nii.gz)
52
+ if visualization:
53
+ nib_dir = csv_path.replace('.csv', '')
54
+
55
+ # collate in master process
56
+ if is_master():
57
+ # datasets --> labels --> metrics
58
+ datasets_labels_metrics = {} # {'COVID19':{'covid19_infection':{'dice':[0.8, 0.9, ...], ...} ...}, ...}
59
+
60
+ # datasets --> samples --> labels --> metrics
61
+ samples_labels_metrics = {} # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, ...} ...}, ...} records every sample (row) in each dataset
62
+
63
+ # datasets --> labels
64
+ datasets_labels_sets = {} # {'COVID19':set('covid19_infection', ...), ...} records the label types (columns) in each dataset
65
+
66
+ # accumulate scores of each sample in each process
67
+ results_of_samples = [] # each element : [dataset_name, modality, sample_id, scores_of_labels(dict), label_names]
68
+
69
+ # load results from an interrupted eval (only in master process)
70
+ if resume and is_master():
71
+ root_dir = os.path.dirname(csv_path)
72
+ prefix = os.path.basename(csv_path).replace('.csv', '_tmp_rank') # xxx/test/step_xxx.csv --> step_xxx_tmp_rank
73
+ pkl_to_del = []
74
+ for f in os.listdir(root_dir):
75
+ if prefix in f:
76
+ # load list of results
77
+ pkl_path = f'{root_dir}/{f}'
78
+ with open(pkl_path, 'rb') as f:
79
+ results_of_samples += pickle.load(f)
80
+ print(f'Load results from {pkl_path}')
81
+ pkl_to_del.append(pkl_path)
82
+
83
+ # there may be duplication? We leave the deduplication to the final merge
84
+ # merge all the loaded samples, del the tmp pickle files in previous evaluation task
85
+ for pkl_path in pkl_to_del:
86
+ os.remove(pkl_path)
87
+ print(f'Del {pkl_path}')
88
+ merge_pkl = csv_path.replace('.csv', f'_tmp_rank0.pkl')
89
+ with open(merge_pkl, 'wb') as f:
90
+ pickle.dump(results_of_samples, f)
91
+ print(f'Load results of {len(results_of_samples)} samples, Merge into {merge_pkl}')
92
+
93
+ model.eval()
94
+ text_encoder.eval()
95
+
96
+ with torch.no_grad():
97
+
98
+ data_time = 0
99
+ pred_time = 0
100
+ metric_time = 0
101
+
102
+ avg_patch_batch_num = 0
103
+ avg_query_batch_num = 0
104
+
105
+ # in ddp, only master process display the progress bar
106
+ if is_master():
107
+ testloader = tqdm(testloader, disable=False)
108
+ else:
109
+ testloader = tqdm(testloader, disable=True)
110
+
111
+ # gaussian kernel to accumulate predictions
112
+ gaussian = torch.tensor(compute_gaussian((288, 288, 96))).to(device) # hwd
113
+
114
+ end_time = time.time()
115
+ for sample in testloader: # in evaluation/inference, a "batch" in loader is a volume
116
+ # data loading
117
+ dataset_name = sample['dataset_name']
118
+ sample_id = sample['sample_id']
119
+ batched_patches = sample['batched_patches']
120
+ batched_y1y2_x1x2_z1z2 = sample['batched_y1y2_x1x2_z1z2']
121
+ labels = sample['labels']
122
+ gt_segmentation = sample['gt_segmentation'].numpy() # n h w d
123
+ modality = sample['modality']
124
+ image_path = sample['image_path']
125
+
126
+ n,h,w,d = gt_segmentation.shape
127
+ prediction = torch.zeros((n, h, w, d))
128
+ accumulation = torch.zeros((n, h, w, d))
129
+
130
+ data_time += (time.time()-end_time)
131
+ end_time = time.time()
132
+
133
+ with autocast():
134
+
135
+ queries = text_encoder(labels, modality)
136
+
137
+ # for each batch of patches, query with all labels
138
+ for patches, y1y2_x1x2_z1z2_ls in zip(batched_patches, batched_y1y2_x1x2_z1z2): # [b, c, h, w, d]
139
+ patches = patches.to(device=device)
140
+ prediction_patch = model(queries=queries, image_input=patches, train_mode=False)
141
+ prediction_patch = torch.sigmoid(prediction_patch) # bnhwd
142
+ prediction_patch = prediction_patch.detach() # .cpu().numpy()
143
+
144
+ # fill in
145
+ for b in range(len(y1y2_x1x2_z1z2_ls)):
146
+ y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls[b]
147
+
148
+ # gaussian accumulation
149
+ tmp = prediction_patch[b, :, :y2-y1, :x2-x1, :z2-z1] * gaussian[:y2-y1, :x2-x1, :z2-z1] # on gpu
150
+ prediction[:, y1:y2, x1:x2, z1:z2] += tmp.cpu()
151
+ accumulation[:, y1:y2, x1:x2, z1:z2] += gaussian[:y2-y1, :x2-x1, :z2-z1].cpu()
152
+
153
+ pred_time += (time.time()-end_time)
154
+ end_time = time.time()
155
+
156
+ # avg
157
+ prediction = prediction / accumulation
158
+ prediction = torch.where(prediction>0.5, 1.0, 0.0)
159
+ prediction = prediction.numpy()
160
+
161
+ # cal metrics : [{'dice':x, ...}, ...]
162
+ scores = []
163
+ for j in range(len(labels)):
164
+ scores.append(calculate_metric_percase(prediction[j, :, :, :], gt_segmentation[j, :, :, :], dice_score, nsd_score)) # {'dice':0.9, 'nsd':0.8} one dict per label
165
+
166
+ # visualization
167
+ if visualization:
168
+ Path(f'{nib_dir}/{dataset_name}').mkdir(exist_ok=True, parents=True)
169
+ # save the image, gt and prediction
170
+ results = np.zeros((h, w, d)) # hwd
171
+ for j, label in enumerate(labels):
172
+ results += prediction[j, :, :, :] * (j+1) # 0 --> 1 (skip background)
173
+ Path(f'{nib_dir}/{dataset_name}/seg_{sample_id}').mkdir(exist_ok=True, parents=True)
174
+ # one separate nii.gz per label
175
+ segobj = nib.nifti2.Nifti1Image(prediction[j, :, :, :], np.eye(4))
176
+ nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}/{label}.nii.gz')
177
+ segobj = nib.nifti2.Nifti1Image(results, np.eye(4))
178
+ nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}.nii.gz')
179
+
180
+ image = testset.load_image(image_path)
181
+ image = np.squeeze(image)
182
+ imgobj = nib.nifti2.Nifti1Image(image, np.eye(4))
183
+ nib.save(imgobj, f'{nib_dir}/{dataset_name}/img_{sample_id}.nii.gz')
184
+
185
+ gt = np.zeros((h, w, d)) # hwd
186
+ for j, label in enumerate(labels):
187
+ gt += gt_segmentation[j, :, :, :] * (j+1) # 0 --> 1 (skip background)
188
+ Path(f'{nib_dir}/{dataset_name}/gt_{sample_id}').mkdir(exist_ok=True, parents=True)
189
+ # one separate nii.gz per label
190
+ segobj = nib.nifti2.Nifti1Image(gt_segmentation[j, :, :, :], np.eye(4))
191
+ nib.save(segobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}/{label}.nii.gz')
192
+ gtobj = nib.nifti2.Nifti1Image(gt, np.eye(4))
193
+ nib.save(gtobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}.nii.gz')
194
+
195
+ metric_time += (time.time()-end_time)
196
+ end_time = time.time()
197
+
198
+ # accumulate
199
+ results_of_samples.append([dataset_name, modality, sample_id, scores, labels])
200
+
201
+ # save in each process regularly in case of interruption
202
+ if len(results_of_samples) % save_interval == 0:
203
+ with open(csv_path.replace('.csv', f'_tmp_rank{dist.get_rank()}.pkl'), 'wb') as f:
204
+ pickle.dump(results_of_samples, f)
205
+
206
+ """
207
+ # gather results from all device to rank-0 (solution 1)
208
+ gather_results = [None for i in range(dist.get_world_size())]
209
+ dist.gather_object(
210
+ results_of_samples,
211
+ gather_results if dist.get_rank() == 0 else None,
212
+ dst = 0
213
+ )
214
+
215
+ if int(dist.get_rank()) == 0:
216
+ results_of_samples = [tmp for ls in results_of_samples for tmp in ls]
217
+ """
218
+
219
+ avg_patch_batch_num /= len(testloader)
220
+ avg_query_batch_num /= len(testloader)
221
+ data_time /= len(testloader)
222
+ pred_time /= len(testloader)
223
+ metric_time /= len(testloader)
224
+ print(f'On Rank {dist.get_rank()}, each sample has {avg_patch_batch_num} batch of patches and {avg_query_batch_num} batch of queries, Data Time: {data_time}, Pred Time: {pred_time}, Dice Time: {metric_time}')
225
+
226
+ torch.cuda.empty_cache()
227
+
228
+ # save in each process (to a fnl pickle, also denoting this process ends)
229
+ with open(csv_path.replace('.csv', f'_fnl_rank{dist.get_rank()}.pkl'), 'wb') as f:
230
+ pickle.dump(results_of_samples, f)
231
+
232
+ # gather and record in rank 0 (solution 2)
233
+ if is_master():
234
+
235
+ # detect the finish of each process
236
+ while True:
237
+ all_process_finished = True
238
+ for rank_id in range(torch.distributed.get_world_size()):
239
+ if not os.path.exists(csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl')): # xxx_tmp_rankx.pkl
240
+ all_process_finished = False
241
+ break
242
+ if all_process_finished:
243
+ break
244
+ else:
245
+ time.sleep(10)
246
+
247
+ # read results of each process (samples may be duplicated due to the even distribution of ddp, check)
248
+ results_of_samples = []
249
+ for rank_id in range(torch.distributed.get_world_size()):
250
+ fnl_results_file = csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl')
251
+ tmp_results_file = csv_path.replace('.csv', f'_tmp_rank{rank_id}.pkl')
252
+ with open(fnl_results_file, 'rb') as f:
253
+ results_of_samples += pickle.load(f)
254
+ os.remove(fnl_results_file)
255
+ if os.path.exists(tmp_results_file):
256
+ os.remove(tmp_results_file)
257
+
258
+ # check duplication
259
+ unique_set = set()
260
+ deduplicated_results_of_samples = []
261
+ for dataset_name, modality, sample_id, scores, labels in results_of_samples:
262
+ if f'{dataset_name}/{sample_id}' not in unique_set:
263
+ unique_set.add(f'{dataset_name}/{sample_id}')
264
+ deduplicated_results_of_samples.append([dataset_name, modality, sample_id, scores, labels])
265
+ results_of_samples = deduplicated_results_of_samples
266
+
267
+ # save for tmp
268
+ with open(csv_path.replace('.csv', '.pkl'), 'wb') as f:
269
+ pickle.dump(results_of_samples, f)
270
+
271
+ # collate results
272
+ for dataset_name, modality, sample_id, scores, labels in results_of_samples: # [[dataset_name, modality, sample_id, scores_of_labels(dict), label_names], ...]
273
+ dataset_name = f'{dataset_name}({modality})'
274
+
275
+ if dataset_name not in datasets_labels_metrics:
276
+ datasets_labels_metrics[dataset_name] = {} # {'COVID19(CT)':{}}
277
+ if dataset_name not in datasets_labels_sets:
278
+ datasets_labels_sets[dataset_name] = set() # {'COVID19(CT)':set()}
279
+ if dataset_name not in samples_labels_metrics:
280
+ samples_labels_metrics[dataset_name] = {}
281
+ samples_labels_metrics[dataset_name][sample_id] = {} # {'COVID19(CT)':{'0':{}}}
282
+
283
+ for metric_dict, label in zip(scores, labels):
284
+ # accumulate metrics (per dataset, per class)
285
+ # {'COVID19(CT)':{'covid19_infection':{'dice':[0.8, 0.9, ...], 'nsd':[0.8, 0.9, ...], ...} ...}, ...}
286
+ if label not in datasets_labels_metrics[dataset_name]:
287
+ datasets_labels_metrics[dataset_name][label] = {k:[v] for k,v in metric_dict.items()}
288
+ else:
289
+ for k,v in metric_dict.items():
290
+ datasets_labels_metrics[dataset_name][label][k].append(v)
291
+
292
+ # statistic labels
293
+ # {'COVID19(CT)':set('covid19_infection', ...)}
294
+ if label not in datasets_labels_sets[dataset_name]:
295
+ datasets_labels_sets[dataset_name].add(label)
296
+
297
+ # record metrics (per dataset, per sample, per class)
298
+ # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, 'nsd':0.9, ...} ...}, ...}
299
+ samples_labels_metrics[dataset_name][sample_id][label] = {k:v for k,v in metric_dict.items()}
300
+
301
+ # average and log (columns are metrics, e.g. dice, nsd, ...)
302
+ # create a df like:
303
+ # {
304
+ # 'TotalSegmentator': [0.xx, 0.xx, ...] # before the transpose (.T), this is a column
305
+ # 'TotalSegmentator, Lung': [0.68, 0.72, ...]
306
+ # }
307
+ # by default, print the dice (1st metric) of each dataset
308
+ info = 'Metrics of Each Dataset:\n'
309
+ avg_df = {}
310
+ for dataset in datasets_labels_metrics.keys():
311
+ avg_df[dataset] = {k:[] for k in metric_dict.keys()} # 'TotalSegmentator(CT)': {'dice':[0.8, ...] 'nsd':[0.5, ...], ...}
312
+ for label in datasets_labels_metrics[dataset].keys():
313
+ avg_df[f'{dataset}, {label}'] = []
314
+ for metric in datasets_labels_metrics[dataset][label].keys():
315
+ label_metric = np.average(datasets_labels_metrics[dataset][label][metric])
316
+ avg_df[f'{dataset}, {label}'].append(label_metric) # 'TotalSegmentator, Lung': [0.68, 0.72, ...] list of num_metrics
317
+ avg_df[dataset][metric].append(label_metric)
318
+ avg_df[dataset] = {k:np.average(v) for k,v in avg_df[dataset].items()} # 'TotalSegmentator': {'dice':[0.8, ...] 'nsd':[0.5, ...], ...} --> 'TotalSegmentator': {'dice':0.x, 'nsd':0.x, ...}
319
+ info += f'{dataset} | '
320
+ for k ,v in avg_df[dataset].items():
321
+ info += f'{v}({k}) | '
322
+ info += '\n'
323
+ avg_df[dataset] = list(avg_df[dataset].values())
324
+ avg_df = pd.DataFrame(avg_df).T
325
+ avg_df.columns = list(metric_dict.keys()) # ['dice', 'nsd']
326
+ avg_df.to_csv(csv_path)
327
+ print(info)
328
+
329
+ # detailed log (nsd and dice, columns are class labels)
330
+ # multi-sheet, two for each dataset
331
+ df_list = [['summary', avg_df]]
332
+ for dataset, label_set in datasets_labels_sets.items():
333
+ metric_df ={}
334
+ if dice_score:
335
+ metric_df['dice'] = {}
336
+ if nsd_score:
337
+ metric_df['nsd'] = {}
338
+
339
+ # create dfs like:
340
+ # {
341
+ # '0.npy': [0.xx, 0.xx, ...]
342
+ # ......
343
+ # }
344
+
345
+ # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, ...} ...}, ...}
346
+ for image_id, label_dict in samples_labels_metrics[dataset].items():
347
+ for metric in metric_df:
348
+ tmp = [] # one dice for each label in this dataset
349
+ for label in label_set:
350
+ score = label_dict[label][metric] if label in label_dict else -1
351
+ tmp.append(score)
352
+ metric_df[metric][image_id] = tmp
353
+
354
+ for metric, metric_df in metric_df.items():
355
+ metric_df = pd.DataFrame(metric_df).T
356
+ metric_df.columns = list(label_set)
357
+ df_list.append([dataset+f'({metric})', metric_df])
358
+
359
+ xlsx_path = csv_path.replace('.csv', '.xlsx')
360
+ with pd.ExcelWriter(xlsx_path) as writer:
361
+ for name, df in df_list:
362
+ # write each DataFrame to its own sheet (sheet name must be < 31 chars)
363
+ if len(name) > 31:
364
+ name = name[len(name)-31:]
365
+ df.to_excel(writer, sheet_name=name, index=True)
366
+
367
+ # avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)
368
+
369
+ os.remove(csv_path.replace('.csv', '.pkl'))
370
+
371
+ else:
372
+
373
+ pass
374
+
375
+ # avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0
376
+
377
+ return # avg_dice_over_merged_labels, avg_nsd_over_merged_labels
378
+
379
+
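
evaluate() blends overlapping patch predictions with the centre-peaked weight from compute_gaussian before thresholding at 0.5. Below is a self-contained sketch of that accumulation (patch size, grid and "predictions" are made up; only the weighting scheme mirrors the code above):

    from itertools import product
    import numpy as np
    from scipy.ndimage import gaussian_filter

    def gaussian_weight(tile_size, sigma_scale=1. / 8):
        # same idea as compute_gaussian(): a centre-peaked, strictly positive weight
        tmp = np.zeros(tile_size)
        tmp[tuple(i // 2 for i in tile_size)] = 1
        g = gaussian_filter(tmp, [i * sigma_scale for i in tile_size], 0, mode='constant', cval=0)
        g = g / g.max()
        g[g == 0] = g[g != 0].min()   # never zero, so the division below stays finite
        return g.astype(np.float32)

    patch = (96, 96, 32)                                   # hypothetical patch size
    vol = np.zeros((1, 160, 160, 48), dtype=np.float32)    # weighted probability accumulator
    acc = np.zeros_like(vol)                               # weight accumulator
    g = gaussian_weight(patch)

    rng = np.random.default_rng(0)
    for y1, x1, z1 in product((0, 64), (0, 64), (0, 16)):  # grid that tiles the whole volume
        p = rng.random((1,) + patch).astype(np.float32)    # stand-in for sigmoid(model(patch))
        vol[:, y1:y1+patch[0], x1:x1+patch[1], z1:z1+patch[2]] += p * g
        acc[:, y1:y1+patch[0], x1:x1+patch[1], z1:z1+patch[2]] += g

    prob = vol / acc                                       # Gaussian-weighted average over overlaps
    mask = (prob > 0.5).astype(np.uint8)                   # hard mask, as in evaluate()
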
evaluate/merge_after_evaluate.py ADDED
@@ -0,0 +1,198 @@
1
+ import json
2
+
3
+ import pandas as pd
4
+ import openpyxl
5
+
6
+ def merge(mod_label_json, mod_label_statistic, xlsx2load, xlsx2save):
7
+ mod_lab2dice = {}
8
+
9
+ # Load the first sheet of the Excel file
10
+ excel_file_path = xlsx2load
11
+ df = pd.read_excel(excel_file_path, sheet_name=0)
12
+ has_nsd = len(df.columns) > 2
13
+
14
+ # 将Dataset Merged 写入新的工作表
15
+ workbook = openpyxl.load_workbook(xlsx2load)
16
+ new_sheet = workbook.create_sheet(title='Dataset Merge', index=1)
17
+ new_sheet.cell(row=1, column=1, value='Dataset')
18
+ new_sheet.cell(row=1, column=2, value='Dice')
19
+ new_sheet.cell(row=1, column=3, value='NSD')
20
+ row = 2
21
+ for i in range(0, len(df)):
22
+ if ',' not in df.iloc[i, 0]:
23
+ new_sheet.cell(row=row, column=1, value=df.iloc[i, 0])
24
+ new_sheet.cell(row=row, column=2, value=df.iloc[i, 1])
25
+ if has_nsd:
26
+ new_sheet.cell(row=row, column=3, value=df.iloc[i, 2])
27
+ row += 1
28
+
29
+ # with pd.ExcelWriter(xlsx2save, engine='openpyxl', mode='a', if_sheet_exists='new') as writer:
30
+ # filtered_df.to_excel(writer, sheet_name='Dataset Merge', index=False)
31
+
32
+ # 选取前两列
33
+ dataset_label_ls = df.iloc[:, 0]
34
+ dice_ls = df.iloc[:, 1]
35
+ nsd_ls = df.iloc[:, 2] if has_nsd else [0] * len(df)
36
+
37
+ for dataset_modality_label, dice, nsd in zip(dataset_label_ls, dice_ls, nsd_ls): # MSD_Pancreas(ct), pancreas 0.89
38
+ if ', ' not in dataset_modality_label:
39
+ continue
40
+ dataset_modality, label = dataset_modality_label.split(', ')
41
+ label = label.lower() # pancreas
42
+ # label = merge_label(label)
43
+ modality = dataset_modality.split('(')[-1].split(')')[0] # ct
44
+
45
+ # unique id : modality_label
46
+ mod_lab = f'{modality}_{label}'
47
+
48
+ # accumulate : dice and where the dice comes from (dataset, label, modality)
49
+ if mod_lab not in mod_lab2dice:
50
+ mod_lab2dice[mod_lab] = {'dice':[], 'nsd':[], 'merge':[]}
51
+ mod_lab2dice[mod_lab]['dice'].append(dice)
52
+ mod_lab2dice[mod_lab]['nsd'].append(nsd)
53
+ mod_lab2dice[mod_lab]['merge'].append(dataset_modality_label)
54
+
55
+ # retrieval regions
56
+ with open(mod_label_json, 'r') as f:
57
+ dict = json.load(f)
58
+ region2label = dict['region_based']
59
+ for region, label_ls in region2label.items():
60
+ region2label[region] = [mod_lab.split('_')[-1] for mod_lab in label_ls] # 去除modality
61
+ region2label['abnormal'] = [mod_lab.split('_')[-1] for mod_lab in dict['abnormal']]
62
+
63
+ region_dice_ls = {k:[] for k in region2label.keys()} # {'brain':[0.9, ...], ...}
64
+ region_nsd_ls = {k:[] for k in region2label.keys()} # {'brain':[0.9, ...], ...}
65
+ region_merge_ls = {k:[] for k in region2label.keys()} # {'brain':['frontal lobe', ...], ...}
66
+
67
+ mod_lab_ls = []
68
+ dice_ls = []
69
+ nsd_ls = []
70
+ merge_ls = []
71
+ region_ls = []
72
+ for mod_lab, dict in mod_lab2dice.items():
73
+ label = mod_lab.split('_')[-1]
74
+ mod_lab_ls.append(mod_lab)
75
+ dice_ls.append(sum(dict['dice'])/len(dict['dice']))
76
+ nsd_ls.append(sum(dict['nsd'])/len(dict['nsd']))
77
+ merge_ls.append(' / '.join(dict['merge']))
78
+
79
+ # find region
80
+ if label in region2label['abnormal']:
81
+ region_dice_ls['abnormal'].append(dice_ls[-1])
82
+ region_nsd_ls['abnormal'].append(nsd_ls[-1])
83
+ region_merge_ls['abnormal'].append(mod_lab)
84
+ region_ls.append('abnormal')
85
+ else:
86
+ found = False
87
+ for region, labels_in_region in region2label.items():
88
+ if label in labels_in_region:
89
+ region_dice_ls[region].append(dice_ls[-1])
90
+ region_nsd_ls[region].append(nsd_ls[-1])
91
+ region_merge_ls[region].append(mod_lab)
92
+ region_ls.append(region)
93
+ found = True
94
+ break
95
+ if not found:
96
+ print(label)
97
+ region_ls.append('unknown')
98
+
99
+ df = pd.DataFrame({
100
+ 'Modality_Label': mod_lab_ls,
101
+ 'Dice': dice_ls,
102
+ 'NSD': nsd_ls,
103
+ 'Merge': merge_ls,
104
+ 'Region': region_ls
105
+ })
106
+
107
+ #book = openpyxl.load_workbook(xlsx2save)
108
+ #writer = pd.ExcelWriter(xlsx2save, engine='openpyxl')
109
+ #writer.book = book
110
+
111
+ # with pd.ExcelWriter(xlsx2save, engine='openpyxl', mode='a', if_sheet_exists='new') as writer:
112
+ # df.to_excel(writer, sheet_name='Label Merge', index=False)
113
+
114
+ # 写上anno num和repeat ratio
115
+ with open(mod_label_statistic, 'r') as f:
116
+ statistic_dict = json.load(f)
117
+
118
+ # 将Label Merged DataFrame写入新的工作表
119
+ new_sheet = workbook.create_sheet(title='Label Merge', index=1)
120
+ new_sheet.cell(row=1, column=1, value='Modality_Label')
121
+ new_sheet.cell(row=1, column=2, value='Dice')
122
+ new_sheet.cell(row=1, column=3, value='NSD')
123
+ new_sheet.cell(row=1, column=4, value='Merge')
124
+ new_sheet.cell(row=1, column=5, value='Region')
125
+ new_sheet.cell(row=1, column=6, value='Total_Num')
126
+ new_sheet.cell(row=1, column=7, value='Aug_Ratio')
127
+ row = 2
128
+ for mod_lab, dice, nsd, merge, region in zip(mod_lab_ls, dice_ls, nsd_ls, merge_ls, region_ls):
129
+ if mod_lab in statistic_dict:
130
+ _, total_num, aug_ratio = statistic_dict[mod_lab]
131
+ else:
132
+ total_num = aug_ratio = 0
133
+ new_sheet.cell(row=row, column=1, value=mod_lab)
134
+ new_sheet.cell(row=row, column=2, value=dice)
135
+ new_sheet.cell(row=row, column=3, value=nsd)
136
+ new_sheet.cell(row=row, column=4, value=merge)
137
+ new_sheet.cell(row=row, column=5, value=region)
138
+ new_sheet.cell(row=row, column=6, value=total_num)
139
+ new_sheet.cell(row=row, column=7, value=aug_ratio)
140
+ row += 1
141
+ new_sheet.cell(row=row, column=2, value=sum(dice_ls)/len(dice_ls)) # avg over all labels
142
+ new_sheet.cell(row=row, column=3, value=sum(nsd_ls)/len(nsd_ls))
143
+
144
+ # 将Region Merged 写入新的工作表
145
+ new_sheet = workbook.create_sheet(title='Region Merge', index=1)
146
+ new_sheet.cell(row=1, column=1, value='Region')
147
+ new_sheet.cell(row=1, column=2, value='Dice')
148
+ new_sheet.cell(row=1, column=3, value='NSD')
149
+ new_sheet.cell(row=1, column=4, value='Merge')
150
+ row = 2
151
+ for key in region_dice_ls.keys():
152
+ if len(region_dice_ls[key]) == 0:
153
+ dice = nsd = 0
154
+ merge = None
155
+ else:
156
+ dice = sum(region_dice_ls[key])/len(region_dice_ls[key])
157
+ nsd = sum(region_nsd_ls[key])/len(region_nsd_ls[key])
158
+ merge = ','.join(region_merge_ls[key])
159
+ class_name = f'{key}({len(region_dice_ls[key])})'
160
+ new_sheet.cell(row=row, column=1, value=class_name)
161
+ new_sheet.cell(row=row, column=2, value=dice)
162
+ new_sheet.cell(row=row, column=3, value=nsd)
163
+ new_sheet.cell(row=row, column=4, value=merge)
164
+ row += 1
165
+
166
+ workbook.save(xlsx2save)
167
+
168
+ # 返回所有 label 的 avg
169
+ avg_dice_over_merged_labels = sum(dice_ls) / len(dice_ls)
170
+ avg_nsd_over_merged_labels = sum(nsd_ls) / len(nsd_ls)
171
+
172
+ return avg_dice_over_merged_labels, avg_nsd_over_merged_labels
173
+
174
+ if __name__ == '__main__':
175
+ import argparse
176
+
177
+ def str2bool(v):
178
+ if isinstance(v, bool):
179
+ return v
180
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
181
+ return True
182
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
183
+ return False
184
+ else:
185
+ raise argparse.ArgumentTypeError('Boolean value expected.')
186
+
187
+ parser = argparse.ArgumentParser()
188
+ parser.add_argument('--xlsx2load', type=str)
189
+ parser.add_argument('--xlsx2save', type=str)
190
+ parser.add_argument('--mod_lab_json', type=str, default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab(72).json')
191
+ parser.add_argument('--mod_label_statistic', type=str, default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab_accum_statis(49).json')
192
+
193
+ config = parser.parse_args()
194
+
195
+ if not config.xlsx2save:
196
+ config.xlsx2save = config.xlsx2load
197
+
198
+ merge(config.mod_lab_json, config.mod_label_statistic, config.xlsx2load, config.xlsx2save)
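
merge() documents its two json inputs only implicitly, through how it reads them: the region-split file is expected to hold a 'region_based' dict of region -> list of 'modality_label' ids plus an 'abnormal' list, and the statistics file maps each 'modality_label' to a 3-element entry whose last two values become Total_Num and Aug_Ratio. A hypothetical example of those shapes (keys, labels and numbers are invented; only the structure is inferred from the code above):

    import json

    mod_lab_split = {
        "region_based": {
            "brain": ["mri_frontal lobe", "ct_brainstem"],
            "abdomen": ["ct_pancreas", "ct_liver"],
        },
        "abnormal": ["ct_covid19 infection"],
    }
    mod_lab_statistic = {
        "ct_pancreas": ["", 420, 1.5],   # [*, Total_Num, Aug_Ratio]
    }

    with open("mod_lab_example.json", "w") as f:
        json.dump(mod_lab_split, f, indent=2)
    with open("mod_lab_statistic_example.json", "w") as f:
        json.dump(mod_lab_statistic, f, indent=2)
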
evaluate/metric.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ import numpy as np
3
+ import time
4
+ from medpy import metric
5
+ from .SurfaceDice import compute_surface_distances, compute_surface_dice_at_tolerance
6
+
7
+ def calculate_metric_percase(pred, gt, dice=True, nsd=True):
8
+ pred = pred.astype(bool)
9
+ gt = gt.astype(bool)
10
+
11
+ metrics = {}
12
+
13
+ if np.sum(gt) == 0.0:
14
+ if np.sum(pred) == 0.0:
15
+ if dice:
16
+ metrics['dice'] = 1.0
17
+ if nsd:
18
+ metrics['nsd'] = 1.0
19
+ else:
20
+ if dice:
21
+ metrics['dice'] = 0.0
22
+ if nsd:
23
+ metrics['nsd'] = 0.0
24
+ return metrics
25
+
26
+ if dice:
27
+ dice_score = metric.binary.dc(pred, gt)
28
+ metrics['dice'] = dice_score
29
+
30
+ if nsd:
31
+ surface_distances = compute_surface_distances(gt, pred, [1, 1, 3])
32
+ nsd_score = compute_surface_dice_at_tolerance(surface_distances, 1)
33
+ metrics['nsd'] = nsd_score
34
+
35
+ return metrics
36
+
37
+ if __name__ == '__main__':
38
+ pred = torch.zeros((3, 256, 256, 16)).numpy()
39
+ pred[:, 0:128, 0:128, :] = 1.0
40
+ gt = torch.zeros((3, 256, 256, 16)).numpy()
41
+ gt[:, 0:64, 0:64, :] = 1.0
42
+ dice = calculate_metric_percase(pred, gt)['dice']
43
+ print(dice)
44
+
45
+
46
+
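
calculate_metric_percase hard-codes the empty-mask convention (both masks empty scores 1.0; an empty ground truth with any predicted foreground scores 0.0) and an NSD spacing of [1, 1, 3] mm with a 1 mm tolerance. A small illustration (toy arrays; assumes medpy and the repository root are importable):

    import numpy as np
    from evaluate.metric import calculate_metric_percase

    empty = np.zeros((8, 8, 4))
    blob = np.zeros((8, 8, 4))
    blob[2:5, 2:5, 1:3] = 1.0

    print(calculate_metric_percase(empty, empty))  # {'dice': 1.0, 'nsd': 1.0}  both empty
    print(calculate_metric_percase(blob, empty))   # {'dice': 0.0, 'nsd': 0.0}  gt empty, pred not
    print(calculate_metric_percase(blob, blob))    # perfect overlap -> dice 1.0, nsd 1.0
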
evaluate/params.py ADDED
@@ -0,0 +1,153 @@
1
+ import argparse
2
+
3
+ def str2bool(v):
4
+ return v.lower() in ('true', 't')
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser()
8
+
9
+ # Exp Controller
10
+
11
+ parser.add_argument(
12
+ "--rcd_dir",
13
+ type=str,
14
+ help="save the evaluation results (in a directory)",
15
+ )
16
+ parser.add_argument(
17
+ "--rcd_file",
18
+ type=str,
19
+ help="save the evaluation results (in a csv/xlsx file)",
20
+ )
21
+ parser.add_argument(
22
+ "--visualization",
23
+ type=str2bool,
24
+ default=False,
25
+ help="save the visualization for each case (img, gt, pred)",
26
+ )
27
+ parser.add_argument(
28
+ "--checkpoint",
29
+ type=str,
30
+ help="Checkpoint path",
31
+ )
32
+ parser.add_argument(
33
+ "--partial_load",
34
+ type=str2bool,
35
+ default=True,
36
+ help="Allow to load partial paramters from checkpoint",
37
+ )
38
+ parser.add_argument(
39
+ "--gpu",
40
+ type=str,
41
+ default=None,
42
+ )
43
+ parser.add_argument(
44
+ "--resume",
45
+ type=str2bool,
46
+ default=True,
47
+ help="Inherit medial results from an interrupted evaluation (no harm even if you evaluate from scratch)",
48
+ )
49
+ parser.add_argument(
50
+ "--save_interval",
51
+ type=int,
52
+ default=100
53
+ )
54
+
55
+ # Metrics
56
+
57
+ parser.add_argument(
58
+ "--dice",
59
+ type=str2bool,
60
+ default=True,
61
+ )
62
+ parser.add_argument(
63
+ "--nsd",
64
+ type=str2bool,
65
+ default=True,
66
+ )
67
+
68
+ # Med SAM Dataset
69
+
70
+ parser.add_argument(
71
+ "--datasets_jsonl",
72
+ type=str,
73
+ )
74
+ parser.add_argument(
75
+ "--text_prompts_json",
76
+ type=str,
77
+ help='This is needed for the CVPR25 challenge, where multiple prompts (synonyms) are required.'
78
+ )
79
+
80
+ # Sampler and Loader
81
+
82
+ parser.add_argument(
83
+ "--online_crop",
84
+ type=str2bool,
85
+ default='False',
86
+ help='load pre-cropped image patches directly, or crop online',
87
+ )
88
+ parser.add_argument(
89
+ "--crop_size",
90
+ type=int,
91
+ nargs='+',
92
+ default=[288, 288, 96],
93
+ )
94
+ parser.add_argument(
95
+ "--max_queries",
96
+ type=int,
97
+ default=256,
98
+ )
99
+ parser.add_argument(
100
+ "--batchsize_3d",
101
+ type=int,
102
+ default=2,
103
+ )
104
+ parser.add_argument(
105
+ "--pin_memory",
106
+ type=str2bool,
107
+ default=False,
108
+ help='pin host memory to speed up host-to-GPU data transfer'
109
+ )
110
+ parser.add_argument(
111
+ "--num_workers",
112
+ type=int,
113
+ default=4
114
+ )
115
+
116
+ # Knowledge Encoder
117
+ parser.add_argument(
118
+ "--text_encoder_partial_load",
119
+ type=str2bool,
120
+ default=True,
121
+ help="Allow to load partial paramters from checkpoint",
122
+ )
123
+ parser.add_argument(
124
+ "--text_encoder_checkpoint",
125
+ type=str,
126
+ )
127
+ parser.add_argument(
128
+ "--text_encoder",
129
+ type=str,
130
+ )
131
+
132
+ # MaskFormer
133
+
134
+ parser.add_argument(
135
+ "--vision_backbone",
136
+ type=str,
137
+ help='UNET or UNET-H'
138
+ )
139
+ parser.add_argument(
140
+ "--patch_size",
141
+ type=int,
142
+ nargs='+',
143
+ default=[32, 32, 32],
144
+ help='patch size along h, w and d'
145
+ )
146
+ parser.add_argument(
147
+ "--deep_supervision",
148
+ type=str2bool,
149
+ default=False,
150
+ )
151
+
152
+ args = parser.parse_args()
153
+ return args
inference_medals_nifti.py ADDED
@@ -0,0 +1,1885 @@
1
+ """
2
+ Medal-S inference script for generic raw image segmentation.
3
+
4
+ This script provides an interface for running Medal-S inference
5
+ on raw NIfTI images. It supports both single-stage (Stage 2 only) and
6
+ two-stage (Stage 1 + Stage 2) inference modes.
7
+
8
+ Usage:
9
+ python inference_medals_nifti.py --input input.nii.gz --output output.nii.gz \\
10
+ --modality CT --texts "Aorta observed in abdominal CT scans" --labels 1
11
+
12
+ # Or use JSON configuration file:
13
+ python inference_medals_nifti.py --input input.nii.gz --output output.nii.gz \\
14
+ --config config.json --mode stage1+stage2
15
+
16
+ Author: Pengcheng Shi
17
+ Institute: Medical Image Insights, Inc., Shanghai, China
18
+ Email: shipc1220@gmail.com
19
+ License: Apache License 2.0
20
+ """
21
+
22
+ import os
23
+ import argparse
24
+ import json
25
+ import time
26
+ import math
27
+ import random
28
+ import itertools
29
+ import gc
30
+ import numpy as np
31
+ import SimpleITK as sitk
32
+ import torch
33
+ import torch.nn.functional as F
34
+ from typing import List
35
+ from scipy.ndimage import label, gaussian_filter
36
+ from einops import rearrange
37
+ from tqdm import tqdm
38
+ from torch.cuda.amp import autocast
39
+
40
+ from data.default_resampling import resample_data_or_seg, compute_new_shape, resample_data_or_seg_to_spacing
41
+ from data.resample_torch import resample_torch_fornnunet, resample_torch_simple
42
+ from model.maskformer import Maskformer
43
+ from model.knowledge_encoder import Knowledge_Encoder
44
+
45
+ def adjust_spacing(img_array, img_spacing):
46
+ """
47
+ Adjust spacing based on image dimensions.
48
+
49
+ This function swaps spacing values if the dimension with minimum size
50
+ doesn't match the dimension with maximum spacing.
51
+
52
+ Args:
53
+ img_array: Image array (used for shape reference)
54
+ img_spacing: Spacing array
55
+
56
+ Returns:
57
+ Adjusted spacing array
58
+ """
59
+ img_spacing = np.asarray(img_spacing)
60
+ min_dim_index = np.argmin(img_array.shape)
61
+ max_spacing_index = np.argmax(img_spacing)
62
+
63
+ if (min_dim_index != max_spacing_index) and (img_spacing[max_spacing_index] > 0.5):
64
+ new_order = list(range(len(img_spacing)))
65
+ new_order[min_dim_index], new_order[max_spacing_index] = new_order[max_spacing_index], new_order[min_dim_index]
66
+ img_spacing = img_spacing[new_order]
67
+
68
+ return img_spacing
69
+
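# --- Editorial usage sketch (not part of the uploaded file); assumes adjust_spacing above ---
# Typical thick-slice case: the axis with the fewest voxels should carry the largest spacing.
import numpy as np
vol = np.zeros((512, 512, 40))                              # only 40 slices along the last axis
print(adjust_spacing(vol, np.array([3.0, 0.7, 0.7])))       # -> [0.7 0.7 3.0]; values swapped to match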
70
+
71
+ def remove_small_objects_binary(binary_data, min_size=10):
72
+ """
73
+ Remove small objects from binary data.
74
+
75
+ Args:
76
+ binary_data: Binary array
77
+ min_size: Minimum size threshold for objects to keep
78
+
79
+ Returns:
80
+ Binary array with small objects removed
81
+ """
82
+ labeled_array, num_features = label(binary_data)
83
+ sizes = np.bincount(labeled_array.ravel())
84
+ remove = sizes < min_size
85
+ remove[0] = False # Ensure the background (label 0) is not removed
86
+ labeled_array[remove[labeled_array]] = 0
87
+ return labeled_array > 0
88
+
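# --- Editorial usage sketch (not part of the uploaded file); assumes remove_small_objects_binary above ---
import numpy as np
noisy = np.zeros((32, 32, 32), dtype=bool)
noisy[2:20, 2:20, 2:20] = True                 # large connected component: kept
noisy[30, 30, 30] = True                       # isolated voxel below min_size: removed
print(remove_small_objects_binary(noisy, min_size=10).sum())   # 18**3 = 5832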
89
+
90
+ def respace_image(image: np.ndarray, current_spacing: np.ndarray, target_spacing: np.ndarray, device: torch.device) -> np.ndarray:
91
+ """
92
+ Resample image to target spacing.
93
+
94
+ Args:
95
+ image: Input image array with shape (C, H, W, D)
96
+ current_spacing: Current spacing array
97
+ target_spacing: Target spacing array
98
+ device: PyTorch device for resampling
99
+
100
+ Returns:
101
+ Resampled image array
102
+ """
103
+ new_shape = compute_new_shape(image.shape[1:], current_spacing, target_spacing)
104
+ resampled_image = resample_torch_fornnunet(
105
+ image, new_shape, current_spacing, target_spacing,
106
+ is_seg=False, num_threads=8, device=device,
107
+ memefficient_seg_resampling=False,
108
+ force_separate_z=None,
109
+ separate_z_anisotropy_threshold=3.0
110
+ )
111
+ return resampled_image
112
+
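# --- Editorial note (not part of the uploaded file) ---
# compute_new_shape scales each axis by current_spacing / target_spacing, so a (1, 512, 512, 40)
# volume at (0.7, 0.7, 3.0) mm resampled to (1.5, 1.5, 3.0) mm comes out at roughly
# round(512 * 0.7 / 1.5) = 239 voxels in-plane, with the 40 slices left unchanged.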
113
+
114
+ def respace_mask(mask: np.ndarray, current_spacing: np.ndarray, target_spacing: np.ndarray, device: torch.device) -> np.ndarray:
115
+ """
116
+ Resample mask to target spacing.
117
+
118
+ Args:
119
+ mask: Input mask array with shape (C, H, W, D)
120
+ current_spacing: Current spacing array
121
+ target_spacing: Target spacing array
122
+ device: PyTorch device for resampling
123
+
124
+ Returns:
125
+ Resampled mask array
126
+ """
127
+ new_shape = compute_new_shape(mask.shape[1:], current_spacing, target_spacing)
128
+ resampled_mask = resample_torch_fornnunet(
129
+ mask, new_shape, current_spacing, target_spacing,
130
+ is_seg=True, num_threads=8, device=device,
131
+ memefficient_seg_resampling=False,
132
+ force_separate_z=None,
133
+ separate_z_anisotropy_threshold=3.0
134
+ )
135
+ return resampled_mask
136
+
137
+
138
+ def split_3d(image_tensor, crop_size=[288, 288, 96]):
139
+ """
140
+ Split 3D image into overlapping patches.
141
+
142
+ Patches are extracted with 50% overlap (stride = crop_size / 2) to ensure
143
+ complete coverage of the image volume.
144
+
145
+ Args:
146
+ image_tensor: Input image tensor with shape (C, H, W, D)
147
+ crop_size: Size of each patch [h, w, d]
148
+
149
+ Returns:
150
+ split_patch: List of patch tensors
151
+ split_idx: List of patch indices [h_s, h_e, w_s, w_e, d_s, d_e]
152
+ """
153
+ interval_h, interval_w, interval_d = crop_size[0] // 2, crop_size[1] // 2, crop_size[2] // 2
154
+ split_idx = []
155
+ split_patch = []
156
+
157
+ c, h, w, d = image_tensor.shape
158
+ h_crop = max(math.ceil(h / interval_h) - 1, 1)
159
+ w_crop = max(math.ceil(w / interval_w) - 1, 1)
160
+ d_crop = max(math.ceil(d / interval_d) - 1, 1)
161
+
162
+ for i in range(h_crop):
163
+ h_s = i * interval_h
164
+ h_e = h_s + crop_size[0]
165
+ if h_e > h:
166
+ h_s = h - crop_size[0]
167
+ h_e = h
168
+ if h_s < 0:
169
+ h_s = 0
170
+ for j in range(w_crop):
171
+ w_s = j * interval_w
172
+ w_e = w_s + crop_size[1]
173
+ if w_e > w:
174
+ w_s = w - crop_size[1]
175
+ w_e = w
176
+ if w_s < 0:
177
+ w_s = 0
178
+ for k in range(d_crop):
179
+ d_s = k * interval_d
180
+ d_e = d_s + crop_size[2]
181
+ if d_e > d:
182
+ d_s = d - crop_size[2]
183
+ d_e = d
184
+ if d_s < 0:
185
+ d_s = 0
186
+ split_idx.append([h_s, h_e, w_s, w_e, d_s, d_e])
187
+ split_patch.append(image_tensor[:, h_s:h_e, w_s:w_e, d_s:d_e])
188
+
189
+ return split_patch, split_idx
190
+
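# --- Editorial usage sketch (not part of the uploaded file); assumes split_3d above ---
import torch
vol = torch.zeros((1, 320, 320, 100))
patches, idxs = split_3d(vol, crop_size=[288, 288, 96])
print(len(patches), idxs[0])    # 8 patches with ~50% overlap (stride = crop_size // 2); first index is [0, 288, 0, 288, 0, 96]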
191
+
192
+ def pad_if_necessary(image, crop_size=[288, 288, 96]):
193
+ """
194
+ Pad image if necessary to meet crop size requirements.
195
+
196
+ Args:
197
+ image: Input image tensor with shape (C, H, W, D)
198
+ crop_size: Minimum size requirements [h, w, d]
199
+
200
+ Returns:
201
+ padded_image: Padded image tensor
202
+ padding_info: Tuple of padding amounts (pad_h, pad_w, pad_d)
203
+ """
204
+ c, h, w, d = image.shape
205
+ croph, cropw, cropd = crop_size
206
+ pad_in_h = 0 if h >= croph else croph - h
207
+ pad_in_w = 0 if w >= cropw else cropw - w
208
+ pad_in_d = 0 if d >= cropd else cropd - d
209
+
210
+ padding_info = (pad_in_h, pad_in_w, pad_in_d)
211
+
212
+ if pad_in_h + pad_in_w + pad_in_d > 0:
213
+ pad = (0, pad_in_d, 0, pad_in_w, 0, pad_in_h)
214
+ image = F.pad(image, pad, 'constant', 0)
215
+
216
+ return image, padding_info
217
+
218
+
219
+ def remove_padding(padded_image, padding_info):
220
+ """
221
+ Remove padding from image.
222
+
223
+ Args:
224
+ padded_image: Padded image (can be torch.Tensor or numpy array)
225
+ padding_info: Tuple of padding amounts (pad_h, pad_w, pad_d)
226
+
227
+ Returns:
228
+ Image with padding removed
229
+ """
230
+ pad_in_h, pad_in_w, pad_in_d = padding_info
231
+
232
+ if len(padded_image.shape) == 4:
233
+ if isinstance(padded_image, torch.Tensor):
234
+ return padded_image[:, :padded_image.shape[1]-pad_in_h, :padded_image.shape[2]-pad_in_w, :padded_image.shape[3]-pad_in_d]
235
+ else:
236
+ return padded_image[:, :padded_image.shape[1]-pad_in_h, :padded_image.shape[2]-pad_in_w, :padded_image.shape[3]-pad_in_d]
237
+ else:
238
+ if isinstance(padded_image, torch.Tensor):
239
+ return padded_image[:padded_image.shape[0]-pad_in_h, :padded_image.shape[1]-pad_in_w, :padded_image.shape[2]-pad_in_d]
240
+ else:
241
+ return padded_image[:padded_image.shape[0]-pad_in_h, :padded_image.shape[1]-pad_in_w, :padded_image.shape[2]-pad_in_d]
242
+
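# --- Editorial usage sketch (not part of the uploaded file); assumes pad_if_necessary / remove_padding above ---
import torch
small = torch.zeros((1, 200, 200, 60))
padded, pad_info = pad_if_necessary(small, crop_size=[288, 288, 96])
print(padded.shape, pad_info)                   # torch.Size([1, 288, 288, 96]) (88, 88, 36)
print(remove_padding(padded, pad_info).shape)   # padding stripped: torch.Size([1, 200, 200, 60])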
243
+
244
+ def internal_maybe_mirror_and_predict(model=None, queries=None, image_input=None, simulated_lowres_sc_pred=None,
245
+ simulated_lowres_mc_pred=None, mirror_axes=(0, 1, 2)):
246
+ """
247
+ Apply test-time augmentation with mirroring.
248
+
249
+ This function performs inference with multiple mirroring combinations
250
+ and averages the results for improved robustness.
251
+
252
+ Args:
253
+ model: Model to use for prediction
254
+ queries: Query tensor
255
+ image_input: Input image tensor
256
+ simulated_lowres_sc_pred: Simulated low-res single-channel prediction
257
+ simulated_lowres_mc_pred: Simulated low-res multi-channel prediction
258
+ mirror_axes: Axes to mirror (0, 1, 2 for spatial dimensions)
259
+
260
+ Returns:
261
+ Averaged prediction tensor
262
+ """
263
+ prediction = model(queries=queries,
264
+ image_input=image_input,
265
+ simulated_lowres_sc_pred=simulated_lowres_sc_pred,
266
+ simulated_lowres_mc_pred=simulated_lowres_mc_pred,
267
+ train_mode=False)
268
+
269
+ if mirror_axes is not None:
270
+ assert max(mirror_axes) <= image_input.ndim - 3, 'mirror_axes does not match the dimension of the input!'
271
+ mirror_axes = [m + 2 for m in mirror_axes]
272
+ axes_combinations = [
273
+ c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
274
+ ]
275
+ for axes in axes_combinations:
276
+ image_input_fliped = torch.flip(image_input, axes)
277
+ simulated_lowres_sc_pred_fliped = torch.flip(simulated_lowres_sc_pred.unsqueeze(0), axes).squeeze(0) if simulated_lowres_sc_pred is not None else None
278
+ simulated_lowres_mc_pred_fliped = torch.flip(simulated_lowres_mc_pred.unsqueeze(0), axes).squeeze(0) if simulated_lowres_mc_pred is not None else None
279
+ prediction_fliped = model(queries=queries,
280
+ image_input=image_input_fliped,
281
+ simulated_lowres_sc_pred=simulated_lowres_sc_pred_fliped,
282
+ simulated_lowres_mc_pred=simulated_lowres_mc_pred_fliped,
283
+ train_mode=False)
284
+ prediction += torch.flip(prediction_fliped, axes)
285
+ prediction /= (len(axes_combinations) + 1)
286
+ return prediction
287
+
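# --- Editorial note (not part of the uploaded file) ---
# With mirror_axes=(0, 1, 2) the loop above covers every non-empty subset of the three spatial
# axes, i.e. 2**3 - 1 = 7 flipped passes plus the identity pass = 8 forward passes per patch:
import itertools
axes = [m + 2 for m in (0, 1, 2)]   # offset by 2 to skip the batch and channel dims
print(len([c for i in range(len(axes)) for c in itertools.combinations(axes, i + 1)]))   # 7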
288
+
289
+ def compute_patch_prediction(
290
+ queries: torch.Tensor,
291
+ patches: torch.Tensor,
292
+ lowres_single_channel_pred: torch.Tensor,
293
+ lowres_multi_channel_pred: torch.Tensor,
294
+ model: torch.nn.Module,
295
+ possible_block_sizes: List[int],
296
+ n_repeats: int = 1,
297
+ disable_tta: bool = True
298
+ ) -> torch.Tensor:
299
+ """
300
+ Compute patch predictions using complementary masking.
301
+
302
+ This function splits the volume into blocks, processes complementary halves
303
+ using random masks, and combines results. The process is repeated n_repeats
304
+ times with different random masks, and results are averaged.
305
+
306
+ Args:
307
+ queries: Input query tensor, shape (batch, query_dim)
308
+ patches: Image patch tensor, shape (batch, channels, h, w, d)
309
+ lowres_single_channel_pred: Low-res single-channel prediction, shape (1, 1, h, w, d)
310
+ lowres_multi_channel_pred: Low-res multi-channel prediction, shape (1, c, h, w, d)
311
+ model: Trained neural network model
312
+ possible_block_sizes: List of possible block sizes (e.g., [8, 16, 32])
313
+ n_repeats: Number of times to repeat prediction with different masks
314
+ disable_tta: Whether to disable test-time augmentation
315
+
316
+ Returns:
317
+ Averaged patch prediction, shape (1, c, h, w, d)
318
+ """
319
+ # Validate inputs
320
+ if not possible_block_sizes:
321
+ raise ValueError("possible_block_sizes cannot be empty")
322
+ if n_repeats < 1:
323
+ raise ValueError("n_repeats must be at least 1")
324
+
325
+ _, _, h, w, d = lowres_single_channel_pred.shape
326
+ device = lowres_single_channel_pred.device
327
+ prediction_sum = torch.zeros_like(lowres_multi_channel_pred, device=device)
328
+
329
+ def upsample_block_mask(block_mask: torch.Tensor, block_size: int) -> torch.Tensor:
330
+ """Upsample a block mask to full resolution."""
331
+ upsampled = (
332
+ block_mask.unsqueeze(0).unsqueeze(0)
333
+ .repeat_interleave(block_size, dim=2)
334
+ .repeat_interleave(block_size, dim=3)
335
+ .repeat_interleave(block_size, dim=4)
336
+ [:, :, :h, :w, :d]
337
+ ).float()
338
+ return upsampled
339
+
340
+ for _ in range(n_repeats):
341
+ block_size = random.choice(possible_block_sizes)
342
+ n_blocks_h = (h + block_size - 1) // block_size
343
+ n_blocks_w = (w + block_size - 1) // block_size
344
+ n_blocks_d = (d + block_size - 1) // block_size
345
+ total_blocks = n_blocks_h * n_blocks_w * n_blocks_d
346
+
347
+ num_selected = max(1, total_blocks // 2)
348
+ block_mask = torch.zeros(n_blocks_h, n_blocks_w, n_blocks_d, dtype=torch.bool, device=device)
349
+ indices = torch.randperm(total_blocks, device=device)[:num_selected]
350
+ block_mask.view(-1)[indices] = True
351
+
352
+ mask = upsample_block_mask(block_mask, block_size)
353
+ complementary_mask = 1.0 - mask
354
+
355
+ masked_sc_pred = lowres_single_channel_pred * mask
356
+ masked_mc_pred = lowres_multi_channel_pred * mask
357
+
358
+ if disable_tta:
359
+ first_half_pred = model(
360
+ queries=queries,
361
+ image_input=patches,
362
+ simulated_lowres_sc_pred=masked_sc_pred,
363
+ simulated_lowres_mc_pred=masked_mc_pred,
364
+ train_mode=False
365
+ )
366
+ else:
367
+ first_half_pred = internal_maybe_mirror_and_predict(
368
+ model=model,
369
+ queries=queries,
370
+ image_input=patches,
371
+ simulated_lowres_sc_pred=masked_sc_pred,
372
+ simulated_lowres_mc_pred=masked_mc_pred,
373
+ mirror_axes=(0, 1, 2)
374
+ )
375
+
376
+ masked_sc_pred_comp = lowres_single_channel_pred * complementary_mask
377
+ masked_mc_pred_comp = lowres_multi_channel_pred * complementary_mask
378
+
379
+ if disable_tta:
380
+ second_half_pred = model(
381
+ queries=queries,
382
+ image_input=patches,
383
+ simulated_lowres_sc_pred=masked_sc_pred_comp,
384
+ simulated_lowres_mc_pred=masked_mc_pred_comp,
385
+ train_mode=False
386
+ )
387
+ else:
388
+ second_half_pred = internal_maybe_mirror_and_predict(
389
+ model=model,
390
+ queries=queries,
391
+ image_input=patches,
392
+ simulated_lowres_sc_pred=masked_sc_pred_comp,
393
+ simulated_lowres_mc_pred=masked_mc_pred_comp,
394
+ mirror_axes=(0, 1, 2)
395
+ )
396
+
397
+ final_prediction = first_half_pred * complementary_mask + second_half_pred * mask
398
+ prediction_sum += final_prediction
399
+
400
+ return prediction_sum / n_repeats
401
+
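# --- Editorial sketch (not part of the uploaded file) of the complementary masking used above ---
# Half of the coarse blocks are selected at random; the pass prompted with `mask` fills the
# complementary region and vice versa, so every voxel is predicted exactly once per repeat.
import torch
h = w = d = 16; block = 8
nb = (h + block - 1) // block
blocks = torch.zeros(nb, nb, nb, dtype=torch.bool)
blocks.view(-1)[torch.randperm(nb ** 3)[: nb ** 3 // 2]] = True
mask = (blocks.float()
        .repeat_interleave(block, 0).repeat_interleave(block, 1).repeat_interleave(block, 2))[:h, :w, :d]
assert torch.all(mask + (1.0 - mask) == 1.0)   # the mask and its complement tile the whole volume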
402
+
403
+ def read_npz_data(raw_image, raw_spacing, crop_size=[288, 288, 96],
404
+ target_spacing=[1.5, 1.5, 3.0], scaled_roi_lowres_pred_array=None,
405
+ class_name_list=[], stage_1_flag=False, device=torch.device("cuda", 0), verbose=True):
406
+ """
407
+ Read and preprocess image data for inference.
408
+
409
+ This function handles spacing adjustments, image resampling, padding,
410
+ and patch splitting for the inference pipeline.
411
+
412
+ Args:
413
+ raw_image: Input image array with shape (d, h, w)
414
+ raw_spacing: Spacing array with shape (3,)
415
+ crop_size: Target crop size [h, w, d]
416
+ target_spacing: Target spacing [h, w, d]
417
+ scaled_roi_lowres_pred_array: Optional low-res prediction for ROI-based inference
418
+ class_name_list: List of class names (kept for compatibility, not used)
419
+ stage_1_flag: Whether this is Stage 1 inference (kept for compatibility, not used)
420
+ device: PyTorch device for resampling
421
+ verbose: Whether to print detailed information (default: True)
422
+
423
+ Returns:
424
+ data_dict: Dictionary containing preprocessed patches and metadata
425
+ """
426
+ raw_d, raw_h, raw_w = raw_image.shape
427
+ image = rearrange(raw_image, 'd h w -> h w d')
428
+ spacing = raw_spacing.astype(np.float32)
429
+
430
+ # Simplified spacing adjustment following the provided steps
431
+ # Step 1: Handle very small spacing values
432
+ for i in range(3):
433
+ if spacing[i] <= 0.1:
434
+ spacing[i] = 1.0
435
+
436
+ # Step 2: Adjust spacing based on image dimensions
437
+ spacing = adjust_spacing(image, spacing)
438
+
439
+ # Step 3: Initialize parameters for spacing adjustment
440
+ max_dims = [1000, 1000, 700]
441
+ min_dims = crop_size
442
+ thresholds = []
443
+ current = 1.25
444
+ while current <= 50:
445
+ thresholds.append(current)
446
+ current *= 1.25
447
+ raw_target_spacing = target_spacing.copy()
448
+
449
+ # Step 4: Adjust spacing based on constraints
450
+ for i in range(3):
451
+ # If spacing is less than 1.0 and image dimension is within max_dims, set to 1.0
452
+ if spacing[i] < 1.0 and image.shape[i] <= max_dims[i]:
453
+ spacing[i] = 1.0 # second stage model resolution
454
+
455
+ # If physical dimension exceeds max_dims and spacing is greater than target, use target spacing
456
+ if spacing[i] * image.shape[i] > max_dims[i] * target_spacing[i] and spacing[i] > target_spacing[i]:
457
+ spacing[i] = target_spacing[i]
458
+ # If physical dimension is less than min_dims threshold, adjust target_spacing
459
+ elif spacing[i] * image.shape[i] < min_dims[i] * target_spacing[i]:
460
+ alpha_spacing = 1
461
+ for threshold in reversed(thresholds):
462
+ if image.shape[i] <= (min_dims[i] / threshold):
463
+ alpha_spacing = threshold
464
+ break
465
+
466
+ raw_target_spacing[i] = target_spacing[i]
467
+ target_spacing[i] = max(spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing)
468
+ if verbose:
469
+ print("alpha_spacing: ", alpha_spacing)
470
+ print("spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing: ", spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing)
471
+ print("raw_target_spacing[i], target_spacing[i]: ", raw_target_spacing[i], target_spacing[i])
472
+ target_spacing[i] = min(raw_target_spacing[i], target_spacing[i])
473
+ if verbose:
474
+ print("image.shape[i], min_dims[i], target_spacing[i], spacing[i]: ", image.shape[i], min_dims[i], target_spacing[i], spacing[i])
475
+
476
+ # Set default num_iterations (no special class handling)
477
+ num_iterations = 1
478
+
479
+ image = image[np.newaxis, ...].astype(np.float32)
480
+ if verbose:
481
+ print("image.shape: ", image.shape)
482
+ print("spacing: ", spacing)
483
+ print("target_spacing: ", target_spacing)
484
+ image = respace_image(image, spacing, target_spacing, torch.device('cpu'))
485
+ if verbose:
486
+ print("respace image.shape: ", image.shape)
487
+ image = torch.tensor(image)
488
+ image, padding_info = pad_if_necessary(image, crop_size=crop_size)
489
+ _, h, w, d = image.shape
490
+
491
+ patches, y1y2_x1x2_z1z2_ls = split_3d(image, crop_size=crop_size)
492
+
493
+ data_dict = {
494
+ 'spacing': spacing,
495
+ 'original_shape': (raw_h, raw_w, raw_d),
496
+ 'current_shape': (h, w, d),
497
+ 'patches': patches,
498
+ 'y1y2_x1x2_z1z2_ls': y1y2_x1x2_z1z2_ls,
499
+ 'padding_info': padding_info,
500
+ 'raw_image': raw_image,
501
+ 'num_iterations': num_iterations
502
+ }
503
+
504
+ if scaled_roi_lowres_pred_array is not None:
505
+ lowres_pred = rearrange(scaled_roi_lowres_pred_array, 'd h w -> h w d')
506
+ lowres_pred = lowres_pred[np.newaxis, ...].astype(np.float32)
507
+ lowres_pred = respace_mask(lowres_pred, spacing, target_spacing, torch.device('cpu'))
508
+ lowres_pred = torch.tensor(lowres_pred)
509
+ lowres_pred, padding_info = pad_if_necessary(lowres_pred, crop_size=crop_size)
510
+ lowres_pred_patches, _ = split_3d(lowres_pred, crop_size=crop_size)
511
+ data_dict['lowres_pred_patches'] = lowres_pred_patches
512
+ data_dict['padding_info'] = padding_info
513
+
514
+ return data_dict
515
+
516
+
517
+ def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_factor: float = 10, dtype=np.float16):
518
+ """
519
+ Compute Gaussian importance map for patch weighting.
520
+
521
+ This creates a Gaussian weight map centered at the patch center, used for
522
+ weighted averaging of overlapping patch predictions.
523
+
524
+ Args:
525
+ tile_size: Size of the tile (crop_size)
526
+ sigma_scale: Scale factor for Gaussian sigma (relative to tile size)
527
+ value_scaling_factor: Scaling factor for the Gaussian values
528
+ dtype: Data type for the output array
529
+
530
+ Returns:
531
+ Gaussian importance map array
532
+ """
533
+ tmp = np.zeros(tile_size)
534
+ center_coords = [i // 2 for i in tile_size]
535
+ sigmas = [i * sigma_scale for i in tile_size]
536
+ tmp[tuple(center_coords)] = 1
537
+ gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
538
+ gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * value_scaling_factor
539
+ gaussian_importance_map = gaussian_importance_map.astype(dtype)
540
+ gaussian_importance_map[gaussian_importance_map == 0] = np.min(
541
+ gaussian_importance_map[gaussian_importance_map != 0])
542
+ return gaussian_importance_map
543
+
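# --- Editorial usage sketch (not part of the uploaded file); assumes compute_gaussian above ---
g = compute_gaussian((288, 288, 96))
print(g.shape, float(g.max()), bool((g > 0).all()))   # peak value 10 at the patch centre, strictly positive weights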
544
+
545
+ def sc_mask_to_mc_mask(sc_mask, label_values_ls):
546
+ """
547
+ Convert single-channel mask to multi-channel mask.
548
+
549
+ Args:
550
+ sc_mask: Single-channel mask with shape (1, 1, h, w, d) or (h, w, d)
551
+ label_values_ls: List of label values to create channels for
552
+
553
+ Returns:
554
+ Multi-channel mask with shape (1, n_classes, h, w, d)
555
+ """
556
+ sc_mask = sc_mask.squeeze(0).squeeze(0)
557
+ assert sc_mask.ndim == 3
558
+ h, w, d = sc_mask.shape
559
+ n = len(label_values_ls)
560
+ mc_mask = torch.zeros((n, h, w, d), dtype=bool).to(sc_mask.device)
561
+ for i, label_value in enumerate(label_values_ls):
562
+ mc_mask[i] = torch.where(sc_mask == label_value, 1, 0)
563
+ mc_mask = mc_mask.to(torch.float32)
564
+ mc_mask = mc_mask.unsqueeze(0)
565
+ return mc_mask
566
+
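# --- Editorial usage sketch (not part of the uploaded file); assumes sc_mask_to_mc_mask above ---
import torch
sc = torch.zeros((1, 1, 2, 2, 2))
sc[0, 0, 0] = 3.0                         # voxels carrying label value 3
sc[0, 0, 1] = 7.0                         # voxels carrying label value 7
mc = sc_mask_to_mc_mask(sc, [3, 7])
print(mc.shape, mc[0, 0].sum(), mc[0, 1].sum())   # torch.Size([1, 2, 2, 2, 2]) tensor(4.) tensor(4.)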
567
+
568
+ class MedicalSegmentationPipeline:
569
+ """
570
+ Pipeline for medical image segmentation.
571
+
572
+ This class handles model loading, data preprocessing, and inference execution
573
+ for the Medal-S segmentation pipeline.
574
+ """
575
+
576
+ def __init__(self, config):
577
+ """
578
+ Initialize the segmentation pipeline.
579
+
580
+ Args:
581
+ config: Dictionary containing pipeline configuration parameters
582
+ """
583
+ self.config = config
584
+ self.device = torch.device(config['device'])
585
+
586
+ def _load_model(self):
587
+ """
588
+ Load vision model and text encoder from checkpoints.
589
+
590
+ Returns:
591
+ model: Loaded vision model (Maskformer)
592
+ text_encoder: Loaded text encoder (Knowledge_Encoder)
593
+ """
594
+ crop_str = '_'.join(map(str, self.config['crop_size']))
595
+ spacing_str = '_'.join(map(str, self.config['target_spacing_model']))
596
+
597
+ vision_backbone_checkpoint = os.path.join(
598
+ self.config['checkpoints_path'],
599
+ f"nano_UNet_CVPR2025_crop_size_{crop_str}_spacing_{spacing_str}_step_{self.config['model_step']}.pth")
600
+
601
+ model = Maskformer(
602
+ self.config['vision_backbone'],
603
+ self.config['input_channels'],
604
+ self.config['crop_size'],
605
+ self.config['patch_size'],
606
+ False
607
+ )
608
+ model = model.to(self.device)
609
+ checkpoint = torch.load(vision_backbone_checkpoint, map_location=self.device)
610
+ new_state_dict = {
611
+ k[7:] if k.startswith('module.') else k: v
612
+ for k, v in checkpoint['model_state_dict'].items()
613
+ if 'mid_mask_embed_proj' not in k
614
+ }
615
+ model.load_state_dict(new_state_dict)
616
+ model.eval()
617
+
618
+ text_encoder = Knowledge_Encoder(
619
+ biolord_checkpoint=os.path.join(
620
+ self.config['checkpoints_path'],
621
+ 'BioLORD-2023-C'
622
+ )
623
+ )
624
+ text_encoder = text_encoder.to(self.device)
625
+ checkpoint = torch.load(
626
+ os.path.join(self.config['checkpoints_path'], 'text_encoder.pth'),
627
+ map_location=self.device
628
+ )
629
+ new_state_dict = {
630
+ k[7:] if k.startswith('module.') else k: v
631
+ for k, v in checkpoint['model_state_dict'].items()
632
+ }
633
+ text_encoder.load_state_dict(new_state_dict, strict=False)
634
+ text_encoder.eval()
635
+
636
+ return model, text_encoder
637
+
638
+ def run_inference(self, raw_image, raw_spacing, verbose=True):
639
+ """
640
+ Run inference on the input image.
641
+
642
+ This method performs the complete inference pipeline:
643
+ 1. Load models (vision backbone and text encoder)
644
+ 2. Preprocess image data (resampling, padding, patch splitting)
645
+ 3. Encode text prompts
646
+ 4. Process patches and aggregate predictions
647
+ 5. Post-process results (remove padding, resample to original shape)
648
+
649
+ Args:
650
+ raw_image: Input image array with shape (d, h, w)
651
+ raw_spacing: Spacing array with shape (3,)
652
+ verbose: Whether to print detailed information (default: True)
653
+
654
+ Returns:
655
+ pred_array: Segmentation array with shape (d, h, w), dtype int16
656
+ max_prob_array: Maximum probability array (if return_max_prob=True), or None
657
+ """
658
+ model, text_encoder = self._load_model()
659
+ pred_array = None
660
+ crop_size = self.config['crop_size']
661
+ disable_tta = self.config['disable_tta']
662
+ instance_label = self.config['instance_label']
663
+ modality = self.config['modality']
664
+ text_prompts = self.config['texts']
665
+ label_values = self.config['label_values']
666
+ return_max_prob = self.config['return_max_prob']
667
+ class_name_list = self.config['class_name_list']
668
+ stage_1_flag = self.config['stage_1_flag']
669
+ with torch.no_grad():
670
+ # Gaussian is kept on CPU, as accumulation will now happen on CPU
671
+ gaussian = torch.tensor(compute_gaussian(tuple(crop_size)), dtype=torch.float32).cpu()
672
+
673
+ data_dict = read_npz_data(
674
+ raw_image=raw_image,
675
+ raw_spacing=raw_spacing,
676
+ crop_size=crop_size,
677
+ target_spacing=self.config['target_spacing'],
678
+ scaled_roi_lowres_pred_array=self.config['scaled_roi_lowres_pred_array'],
679
+ class_name_list=class_name_list,
680
+ stage_1_flag=stage_1_flag,
681
+ device=self.device,
682
+ verbose=verbose
683
+ )
684
+
685
+ spacing = data_dict['spacing']
686
+ original_shape = data_dict['original_shape']
687
+ current_shape = data_dict['current_shape']
688
+ batched_patches = data_dict['patches']
689
+ batched_y1y2_x1x2_z1z2 = data_dict['y1y2_x1x2_z1z2_ls']
690
+ padding_info = data_dict['padding_info']
691
+ raw_image = data_dict['raw_image']
692
+ num_iterations = data_dict['num_iterations']
693
+ batched_lowres_pred_patches = data_dict.get('lowres_pred_patches')
694
+
695
+ modality_code = torch.tensor([{
696
+ 'ct': 0, 'mri': 1, 'us': 2, 'pet': 3, 'microscopy': 4
697
+ }[modality]]).to(self.device) # Keep modality_code on GPU if text_encoder needs it on GPU
698
+
699
+ h, w, d = current_shape
700
+ n_total_classes = len(text_prompts)
701
+
702
+ # Get category batch size from config, default to 24
703
+ category_batch_size = self.config.get('category_batch_size', 24)
704
+ background_threshold = self.config.get('background_threshold', 0.5)
705
+
706
+ # Initialize max_prob and max_class_label_value on CPU to save GPU memory
707
+ max_prob = torch.zeros((h, w, d), dtype=torch.float32, device='cpu')
708
+ max_class_label_value = torch.zeros((h, w, d), dtype=torch.int16, device='cpu')
709
+
710
+ # Process categories in batches to avoid OOM
711
+ category_range = range(0, n_total_classes, category_batch_size)
712
+ pbar = tqdm(category_range, desc="Processing Categories")
713
+ for i in pbar:
714
+ current_category_texts = text_prompts[i:i + category_batch_size]
715
+ current_label_values = label_values[i:i + category_batch_size]
716
+ current_n = len(current_category_texts)
717
+ end_idx = min(i + current_n - 1, n_total_classes - 1)
718
+
719
+ # Update progress bar description with current category range
720
+ pbar.set_description(f"Processing Categories {i}-{end_idx}")
721
+
722
+ # Keep these large tensors on CPU for accumulation
723
+ temp_prediction_batch_cpu = torch.zeros((current_n, h, w, d), dtype=torch.float32, device='cpu')
724
+ temp_accumulation_batch_cpu = torch.zeros((current_n, h, w, d), dtype=torch.float32, device='cpu')
725
+
726
+ # Encode text prompts for current batch
727
+ with autocast(enabled=False):
728
+ queries = text_encoder(current_category_texts, modality_code, self.device) # queries remain on GPU for model input
729
+
730
+ # Process patches for current category batch
731
+ for patches, lowres_pred_patches, y1y2_x1x2_z1z2_ls in tqdm(
732
+ zip(batched_patches, batched_lowres_pred_patches if batched_lowres_pred_patches is not None else [None]*len(batched_patches), batched_y1y2_x1x2_z1z2),
733
+ total=len(batched_patches),
734
+ desc="Processing",
735
+ ncols=100,
736
+ bar_format="{l_bar}{bar:20}{r_bar}",
737
+ colour="green",
738
+ leave=False
739
+ ):
740
+ patches = patches.unsqueeze(0).to(device=self.device, dtype=torch.float32) # patches on GPU for model input
741
+ y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls
742
+
743
+ simulated_lowres_sc_pred = None
744
+ simulated_lowres_mc_pred = None
745
+
746
+ if not self.config['w_lowres_pred_prompts']:
747
+ simulated_lowres_sc_pred = torch.zeros((1, 1, *crop_size), device=self.device, dtype=torch.float32)
748
+ simulated_lowres_mc_pred = torch.zeros((1, current_n, *crop_size), device=self.device, dtype=torch.float32)
749
+ prediction_patch = model(
750
+ queries=queries,
751
+ image_input=patches,
752
+ simulated_lowres_sc_pred=simulated_lowres_sc_pred,
753
+ simulated_lowres_mc_pred=simulated_lowres_mc_pred,
754
+ train_mode=False
755
+ ) if self.config['disable_tta'] else internal_maybe_mirror_and_predict(
756
+ model=model,
757
+ queries=queries,
758
+ image_input=patches,
759
+ simulated_lowres_sc_pred=simulated_lowres_sc_pred,
760
+ simulated_lowres_mc_pred=simulated_lowres_mc_pred,
761
+ mirror_axes=(0, 1, 2)
762
+ )
763
+ else:
764
+ lowres_pred_patches = lowres_pred_patches.unsqueeze(0).to(device=self.device, dtype=torch.float32)
765
+ simulated_lowres_sc_pred = torch.where(lowres_pred_patches > 0, torch.ones_like(lowres_pred_patches), torch.zeros_like(lowres_pred_patches))
766
+ simulated_lowres_mc_pred = sc_mask_to_mc_mask(lowres_pred_patches, [int(val) for val in current_label_values])
767
+
768
+ possible_block_sizes = [8]
769
+ if instance_label == 1:
770
+ n_repeats = 1
771
+ else:
772
+ n_repeats = 1
773
+ prediction_patch = compute_patch_prediction(queries, patches, simulated_lowres_sc_pred, simulated_lowres_mc_pred, model, possible_block_sizes, n_repeats, disable_tta)
774
+
775
+ if instance_label == 1: # Instance segmentation mode
776
+ for _ in range(num_iterations):
777
+ prediction_patch_prob = torch.sigmoid(prediction_patch).detach()
778
+ simulated_lowres_mc_pred = torch.where(prediction_patch_prob > 0.5, 1.0, 0.0)
779
+ simulated_lowres_sc_pred = (simulated_lowres_mc_pred.sum(dim=1, keepdim=True) > 0).float()
780
+ possible_block_sizes = [4]
781
+ n_repeats = 1
782
+ prediction_patch = compute_patch_prediction(queries, patches, simulated_lowres_sc_pred, simulated_lowres_mc_pred, model, possible_block_sizes, n_repeats, disable_tta)
783
+
784
+ prediction_patch_prob_gpu = torch.sigmoid(prediction_patch).detach()
785
+ current_gaussian_slice = gaussian[:y2-y1, :x2-x1, :z2-z1] # Already on CPU
786
+
787
+ # Perform accumulation on CPU. Move prediction_patch_prob_gpu to CPU here.
788
+ temp_prediction_batch_cpu[:, y1:y2, x1:x2, z1:z2] += (prediction_patch_prob_gpu[0, :, :y2-y1, :x2-x1, :z2-z1].cpu() * current_gaussian_slice)
789
+ temp_accumulation_batch_cpu[:, y1:y2, x1:x2, z1:z2] += current_gaussian_slice
790
+
791
+ # Explicitly delete GPU tensors to free up memory immediately
792
+ del prediction_patch, prediction_patch_prob_gpu, patches
793
+ if simulated_lowres_sc_pred is not None:
794
+ del simulated_lowres_sc_pred
795
+ if simulated_lowres_mc_pred is not None:
796
+ del simulated_lowres_mc_pred
797
+ torch.cuda.empty_cache() # Clear any cached GPU memory after each patch processing
798
+ gc.collect() # Python garbage collection
799
+
800
+ # Normalize predictions by accumulation
801
+ batch_accumulation_cpu = temp_accumulation_batch_cpu
802
+ batch_accumulation_cpu[batch_accumulation_cpu == 0] = 1e-8
803
+ batch_prediction_prob_cpu = temp_prediction_batch_cpu / batch_accumulation_cpu
804
+
805
+ # Update max_prob and max_class_label_value on CPU
806
+ for j in range(current_n):
807
+ class_prob_cpu = batch_prediction_prob_cpu[j, ...] # Already on CPU
808
+ class_label_value_cpu_scalar = torch.tensor(int(current_label_values[j]), dtype=torch.int16, device='cpu') # Already on CPU
809
+
810
+ update_mask_cpu = class_prob_cpu > max_prob
811
+ max_prob[update_mask_cpu] = class_prob_cpu[update_mask_cpu]
812
+ max_class_label_value[update_mask_cpu] = class_label_value_cpu_scalar
813
+
814
+ # Clean up batch tensors
815
+ del temp_prediction_batch_cpu, temp_accumulation_batch_cpu, batch_accumulation_cpu, batch_prediction_prob_cpu, queries
816
+ # Previous patch-level deletions handle GPU memory
817
+
818
+ # Final operations on CPU
819
+ background_indices = max_prob < background_threshold
820
+ max_class_label_value[background_indices] = 0
821
+ results = max_class_label_value.numpy() # Already on CPU, just convert to numpy
822
+
823
+ results = remove_padding(results, padding_info)
824
+ current_h, current_w, current_d = results.shape
825
+ if results.shape != original_shape:
826
+ results = resample_torch_simple(
827
+ results[np.newaxis, ...],
828
+ new_shape=original_shape,
829
+ is_seg=True,
830
+ num_threads=4,
831
+ device=torch.device('cpu'),
832
+ memefficient_seg_resampling=False).squeeze(0)
833
+
834
+ if verbose:
835
+ print(f"Resized segmentation from {current_h, current_w, current_d} to {original_shape}")
836
+
837
+ pred_array = rearrange(results, 'h w d -> d h w').astype(np.int16)
838
+
839
+ if return_max_prob and instance_label == 0:
840
+ # max_prob is already on CPU, just convert to numpy for post-processing
841
+ max_prob_numpy = max_prob.numpy()
842
+ max_prob_numpy = remove_padding(max_prob_numpy, padding_info)
843
+ current_h, current_w, current_d = max_prob_numpy.shape
844
+ if max_prob_numpy.shape != original_shape:
845
+ max_prob_numpy = resample_torch_simple(
846
+ max_prob_numpy[np.newaxis, ...],
847
+ new_shape=original_shape,
848
+ is_seg=False,
849
+ num_threads=4,
850
+ device=torch.device('cpu'),
851
+ memefficient_seg_resampling=False).squeeze(0)
852
+
853
+ if verbose:
854
+ print(f"Resized max probability from {current_h, current_w, current_d} to {original_shape}")
855
+ max_prob = rearrange(max_prob_numpy, 'h w d -> d h w').astype(np.float32)
856
+
857
+ if return_max_prob and instance_label == 0:
858
+ return pred_array, max_prob
859
+ else:
860
+ return pred_array, None
861
+
862
+
863
+ def run_segmentation(
864
+ raw_image,
865
+ raw_spacing,
866
+ crop_size=[192, 192, 96],
867
+ target_spacing=[1.5, 1.5, 3.0],
868
+ target_spacing_model=[1.5, 1.5, 3.0],
869
+ w_lowres_pred_prompts=False,
870
+ scaled_roi_lowres_pred_array=None,
871
+ disable_tta=True,
872
+ model_step=100000,
873
+ vision_backbone="UNET",
874
+ input_channels=2,
875
+ patch_size=[32, 32, 32],
876
+ modality='CT',
877
+ instance_label=0,
878
+ texts=[],
879
+ label_values=[],
880
+ return_max_prob=False,
881
+ class_name_list=[],
882
+ stage_1_flag=False,
883
+ device="cuda:0",
884
+ checkpoints_path="./checkpoints",
885
+ category_batch_size=24,
886
+ background_threshold=0.5,
887
+ verbose=True,
888
+ ):
889
+ """
890
+ Main segmentation function.
891
+
892
+ This function orchestrates the entire segmentation pipeline including
893
+ model loading, data preprocessing, patch-based inference, and result aggregation.
894
+
895
+ Args:
896
+ raw_image: Input image array with shape (d, h, w), dtype uint8, values in [0, 255]
897
+ raw_spacing: Spacing array with shape (3,)
898
+ crop_size: Crop size for patch processing [h, w, d]
899
+ target_spacing: Target spacing for resampling [h, w, d]
900
+ target_spacing_model: Target spacing for model (should match target_spacing)
901
+ w_lowres_pred_prompts: Whether to use low-res predictions as spatial prompts
902
+ scaled_roi_lowres_pred_array: Low-res prediction array for spatial prompts
903
+ disable_tta: Disable test-time augmentation
904
+ model_step: Model checkpoint step number
905
+ vision_backbone: Vision backbone architecture name
906
+ input_channels: Number of input channels
907
+ patch_size: Patch size for the model
908
+ modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy')
909
+ instance_label: 0 for semantic segmentation, 1 for instance segmentation
910
+ texts: List of text prompts (one per class)
911
+ label_values: List of label values (one per class)
912
+ return_max_prob: Whether to return maximum probability map
913
+ class_name_list: List of class names for class-specific adjustments
914
+ stage_1_flag: Whether this is Stage 1 inference
915
+ device: Device string (e.g., 'cuda:0' or 'cpu')
916
+ checkpoints_path: Path to model checkpoints directory
917
+ category_batch_size: Number of categories to process in each batch (default: 24)
918
+ Adjust based on GPU memory. Larger 3D images require smaller batch sizes.
919
+ Accumulation operations are performed on CPU for more stable memory usage.
920
+ background_threshold: Probability threshold for background (default: 0.5)
921
+ Voxels with max probability below this threshold will be labeled as background.
922
+ verbose: Whether to print detailed information (default: True)
923
+
924
+ Returns:
925
+ pred_array: Segmentation array with shape (d, h, w), dtype int16
926
+ max_prob_array: Maximum probability array (if return_max_prob=True), or None
927
+ """
928
+ w_lowres_pred_prompts = scaled_roi_lowres_pred_array is not None
929
+ config = {
930
+ 'device': device,
931
+ 'modality': modality,
932
+ 'instance_label': instance_label,
933
+ 'texts': texts,
934
+ 'label_values': label_values,
935
+ 'vision_backbone': vision_backbone,
936
+ 'crop_size': crop_size,
937
+ 'patch_size': patch_size,
938
+ 'target_spacing': target_spacing,
939
+ 'target_spacing_model': target_spacing_model,
940
+ 'model_step': model_step,
941
+ 'input_channels': input_channels,
942
+ 'w_lowres_pred_prompts': w_lowres_pred_prompts,
943
+ 'scaled_roi_lowres_pred_array': scaled_roi_lowres_pred_array,
944
+ 'disable_tta': disable_tta,
945
+ 'checkpoints_path': checkpoints_path,
946
+ 'return_max_prob': return_max_prob,
947
+ 'class_name_list': class_name_list,
948
+ 'stage_1_flag': stage_1_flag,
949
+ 'category_batch_size': category_batch_size,
950
+ 'background_threshold': background_threshold,
951
+ }
952
+
953
+ pipeline = MedicalSegmentationPipeline(config)
954
+ return pipeline.run_inference(raw_image, raw_spacing, verbose=verbose)
955
+
956
+
957
+ # ============================================================================
958
+ # Main Inference Functions
959
+ # ============================================================================
960
+ # These functions provide the high-level interface for running inference
961
+ # on raw NIfTI images with proper preprocessing and post-processing.
962
+ # ============================================================================
963
+
964
+
965
+ def normalize_image_ct(image_data, window_level=40, window_width=400, window_type='soft_tissue'):
966
+ """
967
+ Normalize CT image using window/level technique.
968
+
969
+ Args:
970
+ image_data: Input CT image array
971
+ window_level: Window level (center of the window). If None, will use default based on window_type
972
+ window_width: Window width (range of the window). If None, will use default based on window_type
973
+ window_type: Type of window ('soft_tissue', 'bone', 'lung'). Used if window_level/window_width are None
974
+
975
+ Returns:
976
+ Normalized image array with dtype uint8, values in [0, 255]
977
+ """
978
+ # Default window settings for different window types
979
+ default_windows = {
980
+ 'soft_tissue': {'window_level': 40, 'window_width': 400},
981
+ 'bone': {'window_level': 500, 'window_width': 1500},
982
+ 'lung': {'window_level': -600, 'window_width': 1500}
983
+ }
984
+
985
+ # Use defaults if not provided
986
+ if window_level is None or window_width is None:
987
+ if window_type in default_windows:
988
+ window_level = default_windows[window_type]['window_level']
989
+ window_width = default_windows[window_type]['window_width']
990
+ else:
991
+ # Fallback to soft_tissue defaults
992
+ window_level = default_windows['soft_tissue']['window_level']
993
+ window_width = default_windows['soft_tissue']['window_width']
994
+
995
+ lower_bound = window_level - window_width / 2
996
+ upper_bound = window_level + window_width / 2
997
+ image_data_pre = np.clip(image_data, lower_bound, upper_bound)
998
+ image_data_pre = (
999
+ (image_data_pre - np.min(image_data_pre))
1000
+ / (np.max(image_data_pre) - np.min(image_data_pre) + 1e-8)
1001
+ * 255.0
1002
+ )
1003
+ return image_data_pre.astype(np.uint8)
1004
+
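# --- Editorial usage sketch (not part of the uploaded file); assumes normalize_image_ct above ---
import numpy as np
ct_hu = np.array([-1000.0, -160.0, 40.0, 240.0, 1000.0])
# soft-tissue window (level 40, width 400) clips HU to [-160, 240] before rescaling to uint8
print(normalize_image_ct(ct_hu, window_level=40, window_width=400))   # -> [0 0 127 254 254] (the +1e-8 guard truncates the top of the window)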
1005
+
1006
+ def normalize_image_other(image_data, percentile_lower=None, percentile_upper=None, preserve_zero=None, normalization_settings=None):
1007
+ """
1008
+ Normalize non-CT images using percentile-based normalization.
1009
+
1010
+ This method clips values to specified percentiles, then
1011
+ normalizes to [0, 255] range while optionally preserving zero values.
1012
+
1013
+ Args:
1014
+ image_data: Input image array
1015
+ percentile_lower: Lower percentile for clipping. If None, will use default or value from normalization_settings
1016
+ percentile_upper: Upper percentile for clipping. If None, will use default or value from normalization_settings
1017
+ preserve_zero: Whether to preserve zero values. If None, will use default or value from normalization_settings
1018
+ normalization_settings: Dictionary containing normalization settings from config.
1019
+ Format: {'percentile_lower': 0.5, 'percentile_upper': 99.5, 'preserve_zero': True}
1020
+
1021
+ Returns:
1022
+ Normalized image array with dtype uint8, values in [0, 255]
1023
+ """
1024
+ # Default normalization settings
1025
+ default_percentile_lower = 0.5
1026
+ default_percentile_upper = 99.5
1027
+ default_preserve_zero = True
1028
+
1029
+ # Use settings from config if provided
1030
+ if normalization_settings is not None:
1031
+ if percentile_lower is None:
1032
+ percentile_lower = normalization_settings.get('percentile_lower', default_percentile_lower)
1033
+ if percentile_upper is None:
1034
+ percentile_upper = normalization_settings.get('percentile_upper', default_percentile_upper)
1035
+ if preserve_zero is None:
1036
+ preserve_zero = normalization_settings.get('preserve_zero', default_preserve_zero)
1037
+ else:
1038
+ # Use defaults if not provided
1039
+ if percentile_lower is None:
1040
+ percentile_lower = default_percentile_lower
1041
+ if percentile_upper is None:
1042
+ percentile_upper = default_percentile_upper
1043
+ if preserve_zero is None:
1044
+ preserve_zero = default_preserve_zero
1045
+
1046
+ # Calculate percentiles from non-zero values
1047
+ non_zero_data = image_data[image_data > 0]
1048
+ if len(non_zero_data) > 0:
1049
+ lower_bound, upper_bound = np.percentile(
1050
+ non_zero_data, [percentile_lower, percentile_upper]
1051
+ )
1052
+ else:
1053
+ # If all values are zero, use min/max
1054
+ lower_bound = np.min(image_data)
1055
+ upper_bound = np.max(image_data)
1056
+
1057
+ image_data_pre = np.clip(image_data, lower_bound, upper_bound)
1058
+ image_data_pre = (
1059
+ (image_data_pre - np.min(image_data_pre))
1060
+ / (np.max(image_data_pre) - np.min(image_data_pre) + 1e-8)
1061
+ * 255.0
1062
+ )
1063
+
1064
+ if preserve_zero:
1065
+ image_data_pre[image_data == 0] = 0
1066
+
1067
+ return image_data_pre.astype(np.uint8)
1068
+
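# --- Editorial usage sketch (not part of the uploaded file); assumes normalize_image_other above ---
import numpy as np
mri = np.random.default_rng(0).gamma(2.0, 100.0, size=(8, 64, 64)).astype(np.float32)
mri[:, :8, :8] = 0.0                                    # background voxels
norm = normalize_image_other(mri)                       # default 0.5 / 99.5 percentile clipping
print(norm.dtype, int(norm[0, 0, 0]), int(norm.min()))  # uint8 0 0 -- zero background is preserved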
1069
+
1070
+ def load_nifti_image(image_path):
1071
+ """
1072
+ Load NIfTI image and extract data, spacing, and metadata.
1073
+
1074
+ Args:
1075
+ image_path: Path to NIfTI image file
1076
+
1077
+ Returns:
1078
+ image_data: Image array with shape (d, h, w)
1079
+ spacing_xyz: Spacing tuple (x, y, z) from SimpleITK
1080
+ metadata: Dictionary containing origin, direction, and spacing_xyz
1081
+ """
1082
+ img_sitk = sitk.ReadImage(image_path)
1083
+ image_data = sitk.GetArrayFromImage(img_sitk) # Shape: (d, h, w)
1084
+ spacing_xyz = img_sitk.GetSpacing() # (x, y, z)
1085
+
1086
+ # Save metadata for output
1087
+ metadata = {
1088
+ 'origin': img_sitk.GetOrigin(),
1089
+ 'direction': img_sitk.GetDirection(),
1090
+ 'spacing_xyz': spacing_xyz
1091
+ }
1092
+
1093
+ return image_data, spacing_xyz, metadata
1094
+
1095
+
1096
+ def convert_spacing(spacing_xyz, image_shape):
1097
+ """
1098
+ Convert spacing from SimpleITK format (x, y, z) to format expected by run_segmentation.
1099
+
1100
+ Following the conversion logic from inference_raw_nifti_2.py:
1101
+ 1. SimpleITK returns (x, y, z)
1102
+ 2. Image from SimpleITK is (d, h, w) where d=z, h=y, w=x
1103
+ 3. Convert to (d, h, w) spacing: (z, x, y) = (d, h, w)
1104
+ 4. Then convert to format expected by run_segmentation: (h, w, d)
1105
+
1106
+ Args:
1107
+ spacing_xyz: Spacing tuple from SimpleITK (x, y, z)
1108
+ image_shape: Image shape (d, h, w)
1109
+
1110
+ Returns:
1111
+ img_spacing: Spacing array in format expected by run_segmentation
1112
+ """
1113
+ img_spacing = np.array(spacing_xyz, dtype=np.float32)
1114
+
1115
+ # Step 1: Convert from (x, y, z) to (d, h, w) spacing
1116
+ # SimpleITK: (x, y, z) -> Image: (d, h, w) where d=z, h=y, w=x
1117
+ # So spacing (x, y, z) -> (z, x, y) = (d, h, w)
1118
+ img_spacing_transposed = img_spacing[[2, 0, 1]] # (z, x, y) = (d, h, w)
1119
+
1120
+ # Step 2: Handle very small spacing values
1121
+ for i in range(3):
1122
+ if img_spacing_transposed[i] < 0.1:
1123
+ img_spacing_transposed[i] = 1.0
1124
+
1125
+ # Step 3: Optional: Adjust spacing based on image dimensions
1126
+ # Note: adjust_spacing expects image in (h, w, d) format, so we need to rearrange
1127
+ # Here we pass a dummy zeros array of the raw (d, h, w) shape purely as a shape reference
1128
+ try:
1129
+ img_spacing_transposed = adjust_spacing(
1130
+ np.zeros(image_shape), # Dummy array for shape reference
1131
+ img_spacing_transposed
1132
+ ).astype(np.float32)
1133
+ except Exception:
1134
+ # If adjust_spacing fails, use spacing as-is
1135
+ pass
1136
+
1137
+ # Step 4: Convert to format expected by run_segmentation
1138
+ # This converts (d, h, w) to (h, w, d)
1139
+ img_spacing = img_spacing_transposed[[1, 2, 0]]
1140
+
1141
+ return img_spacing
1142
+
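# --- Editorial usage sketch (not part of the uploaded file); assumes convert_spacing above ---
# SimpleITK reports spacing as (x, y, z) while GetArrayFromImage returns a (d, h, w) array;
# convert_spacing permutes the values into the (h, w, d) order expected by run_segmentation.
print(convert_spacing((0.7, 0.7, 3.0), image_shape=(40, 512, 512)))   # -> [0.7 0.7 3.0]; through-plane spacing lands on the last (d) axis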
1143
+
1144
+ def run_inference_single_window(
1145
+ image_data,
1146
+ spacing_xyz,
1147
+ metadata,
1148
+ modality='CT',
1149
+ texts=None,
1150
+ label_values=None,
1151
+ inference_mode='stage2_only',
1152
+ device="cuda:0",
1153
+ checkpoints_path="./checkpoints",
1154
+ window_settings=None,
1155
+ window_type='soft_tissue',
1156
+ normalization_settings=None,
1157
+ verbose=True
1158
+ ):
1159
+ """
1160
+ Run inference for a single window type.
1161
+
1162
+ This is an internal function used by run_inference to handle single window type inference.
1163
+
1164
+ Args:
1165
+ image_data: Raw image data array (d, h, w)
1166
+ spacing_xyz: Spacing tuple (x, y, z)
1167
+ metadata: Image metadata dictionary
1168
+ modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy')
1169
+ texts: List of text prompts (one per class)
1170
+ label_values: List of label values (one per class)
1171
+ inference_mode: Inference mode ('stage2_only' or 'stage1+stage2')
1172
+ device: Device to use ('cuda:0' or 'cpu')
1173
+ checkpoints_path: Path to model checkpoints
1174
+ window_settings: Dictionary containing window settings for different window types (CT only)
1175
+ window_type: Type of window to use ('soft_tissue', 'bone', 'lung')
1176
+ normalization_settings: Dictionary containing normalization settings for non-CT modalities
1177
+ verbose: Whether to print detailed information (default: True)
1178
+
1179
+ Returns:
1180
+ pred_array: Segmentation array (d, h, w)
1181
+ """
1182
+ if texts is None:
1183
+ texts = []
1184
+ if label_values is None:
1185
+ label_values = []
1186
+
1187
+ if len(texts) != len(label_values):
1188
+ raise ValueError("Number of text prompts must match number of label values")
1189
+
1190
+ # Normalize image
1191
+ if verbose:
1192
+ print(f"Normalizing image for {window_type} window (modality: {modality})")
1193
+ if modality.upper() == 'CT':
1194
+ # Get window settings from config if available
1195
+ window_level = None
1196
+ window_width = None
1197
+ if window_settings is not None and window_type in window_settings:
1198
+ window_level = window_settings[window_type].get('window_level')
1199
+ window_width = window_settings[window_type].get('window_width')
1200
+ if verbose:
1201
+ print(f"Using {window_type} window: level={window_level}, width={window_width}")
1202
+
1203
+ img_array = normalize_image_ct(image_data, window_level=window_level,
1204
+ window_width=window_width, window_type=window_type)
1205
+ else:
1206
+ # Get normalization settings from config if available
1207
+ if normalization_settings is not None:
1208
+ if verbose:
1209
+ print(f"Using normalization settings from config: {normalization_settings}")
1210
+ img_array = normalize_image_other(image_data, normalization_settings=normalization_settings)
1211
+ else:
1212
+ # Use default normalization
1213
+ if verbose:
1214
+ print("Using default normalization settings")
1215
+ img_array = normalize_image_other(image_data)
1216
+
1217
+ if verbose:
1218
+ print(f"Normalized image range: [{img_array.min()}, {img_array.max()}]")
1219
+
1220
+ # Convert spacing
1221
+ img_spacing = convert_spacing(spacing_xyz, img_array.shape)
1222
+ if verbose:
1223
+ print(f"Converted spacing: {img_spacing}")
1224
+
1225
+ # Run inference
1226
+ if inference_mode == 'stage1+stage2':
1227
+ if verbose:
1228
+ print(f"Running two-stage inference with {window_type} window...")
1229
+ # Stage 1: Low-resolution
1230
+ if verbose:
1231
+ print("Stage 1: Low-resolution segmentation...")
1232
+ stage_1_pred, _ = run_segmentation(
1233
+ raw_image=img_array,
1234
+ raw_spacing=img_spacing,
1235
+ crop_size=[224, 224, 128],
1236
+ target_spacing=[1.5, 1.5, 3.0],
1237
+ target_spacing_model=[1.5, 1.5, 3.0],
1238
+ w_lowres_pred_prompts=False,
1239
+ scaled_roi_lowres_pred_array=None,
1240
+ disable_tta=True,
1241
+ model_step=358600,
1242
+ modality=modality.lower(),
1243
+ instance_label=0,
1244
+ texts=texts,
1245
+ label_values=label_values,
1246
+ return_max_prob=False,
1247
+ class_name_list=[],
1248
+ stage_1_flag=True,
1249
+ device=device,
1250
+ checkpoints_path=checkpoints_path,
1251
+ verbose=verbose
1252
+ )
1253
+
1254
+ # Check if Stage 1 found anything
1255
+ if stage_1_pred.sum() == 0:
1256
+ if verbose:
1257
+ print("Warning: Stage 1 found no predictions. Using Stage 1 result as final output.")
1258
+ final_pred = stage_1_pred
1259
+ else:
1260
+ if verbose:
1261
+ print("Stage 1 completed. Extracting ROI for Stage 2...")
1262
+
1263
+ # Remove small objects from Stage 1 prediction
1264
+ min_size = 10
1265
+ lowres_pred_binary = (stage_1_pred > 0).astype(np.int16)
1266
+ lowres_pred_binary = remove_small_objects_binary(lowres_pred_binary, min_size=min_size).astype(np.int16)
1267
+ stage_1_pred_cleaned = stage_1_pred * lowres_pred_binary
1268
+
1269
+ # Extract ROI from Stage 1 prediction
1270
+ # Find bounding box of non-zero regions
1271
+ non_zero_indices = np.argwhere(stage_1_pred_cleaned > 0)
1272
+ if len(non_zero_indices) == 0:
1273
+ if verbose:
1274
+ print("Warning: No non-zero regions after cleaning. Using Stage 1 result.")
1275
+ final_pred = stage_1_pred_cleaned
1276
+ else:
1277
+ z_min, y_min, x_min = non_zero_indices.min(axis=0)
1278
+ z_max, y_max, x_max = non_zero_indices.max(axis=0)
1279
+
1280
+ # Calculate ROI center and range with scaling factor
1281
+ m = 1.1 # Scaling factor for ROI expansion
1282
+ z_center = (z_min + z_max) / 2
1283
+ y_center = (y_min + y_max) / 2
1284
+ x_center = (x_min + x_max) / 2
1285
+
1286
+ z_range = (z_max - z_min + 1) * m / 2
1287
+ y_range = (y_max - y_min + 1) * m / 2
1288
+ x_range = (x_max - x_min + 1) * m / 2
1289
+
1290
+ # Calculate minimum ranges based on Stage 2 crop size and spacing
1291
+ stage_2_crop_size = [192, 192, 192]
1292
+ stage_2_target_spacing = [1.0, 1.0, 1.0]
1293
+
1294
+ img_spacing_for_roi = img_spacing.copy()
1295
+
1296
+ min_z_range = (stage_2_crop_size[2] / 2) * stage_2_target_spacing[2] / img_spacing_for_roi[2] if img_spacing_for_roi[2] > 0 else z_range
1297
+ min_y_range = (stage_2_crop_size[0] / 2) * stage_2_target_spacing[0] / img_spacing_for_roi[0] if img_spacing_for_roi[0] > 0 else y_range
1298
+ min_x_range = (stage_2_crop_size[1] / 2) * stage_2_target_spacing[1] / img_spacing_for_roi[1] if img_spacing_for_roi[1] > 0 else x_range
1299
+
1300
+ z_range = max(min_z_range - 1, z_range)
1301
+ y_range = max(min_y_range - 1, y_range)
1302
+ x_range = max(min_x_range - 1, x_range)
1303
+
1304
+ z_min_new = max(0, int(z_center - z_range))
1305
+ z_max_new = min(stage_1_pred_cleaned.shape[0] - 1, int(z_center + z_range))
1306
+ y_min_new = max(0, int(y_center - y_range))
1307
+ y_max_new = min(stage_1_pred_cleaned.shape[1] - 1, int(y_center + y_range))
1308
+ x_min_new = max(0, int(x_center - x_range))
1309
+ x_max_new = min(stage_1_pred_cleaned.shape[2] - 1, int(x_center + x_range))
1310
+
1311
+ if verbose:
1312
+ print(f"ROI bounds: z=[{z_min_new}:{z_max_new}], y=[{y_min_new}:{y_max_new}], x=[{x_min_new}:{x_max_new}]")
1313
+
1314
+ roi_array = img_array[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1]
1315
+ roi_lowres_pred = stage_1_pred_cleaned[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1]
1316
+
1317
+ if verbose:
1318
+ print(f"ROI image shape: {roi_array.shape}")
1319
+ print(f"ROI prediction shape: {roi_lowres_pred.shape}")
1320
+
1321
+ # Stage 2: High-resolution segmentation on ROI
1322
+ if verbose:
1323
+ print("Stage 2: High-resolution segmentation on ROI...")
1324
+ roi_pred, _ = run_segmentation(
1325
+ raw_image=roi_array,
1326
+ raw_spacing=img_spacing,
1327
+ crop_size=[192, 192, 192],
1328
+ target_spacing=[1.0, 1.0, 1.0],
1329
+ target_spacing_model=[1.0, 1.0, 1.0],
1330
+ w_lowres_pred_prompts=True,
1331
+ scaled_roi_lowres_pred_array=roi_lowres_pred,
1332
+ disable_tta=True,
1333
+ model_step=341300,
1334
+ modality=modality.lower(),
1335
+ instance_label=0,
1336
+ texts=texts,
1337
+ label_values=label_values,
1338
+ return_max_prob=False,
1339
+ class_name_list=[],
1340
+ stage_1_flag=False,
1341
+ device=device,
1342
+ checkpoints_path=checkpoints_path,
1343
+ verbose=verbose
1344
+ )
1345
+
1346
+ # Integrate ROI prediction back into full volume
1347
+ if verbose:
1348
+ print("Integrating Stage 2 results back into full volume...")
1349
+ final_pred = np.zeros_like(stage_1_pred_cleaned, dtype=np.int16)
1350
+ final_pred[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1] = roi_pred
1351
+ if verbose:
1352
+ print("Stage1+Stage2 inference completed.")
1353
+ elif inference_mode == 'stage2_only':
1354
+ if verbose:
1355
+ print(f"Running Stage 2 inference with {window_type} window...")
1356
+ final_pred, _ = run_segmentation(
1357
+ raw_image=img_array,
1358
+ raw_spacing=img_spacing,
1359
+ crop_size=[192, 192, 192],
1360
+ target_spacing=[1.0, 1.0, 1.0],
1361
+ target_spacing_model=[1.0, 1.0, 1.0],
1362
+ w_lowres_pred_prompts=False,
1363
+ scaled_roi_lowres_pred_array=None,
1364
+ disable_tta=True,
1365
+ model_step=341300,
1366
+ modality=modality.lower(),
1367
+ instance_label=0,
1368
+ texts=texts,
1369
+ label_values=label_values,
1370
+ return_max_prob=False,
1371
+ class_name_list=[],
1372
+ stage_1_flag=False,
1373
+ device=device,
1374
+ checkpoints_path=checkpoints_path,
1375
+ verbose=verbose
1376
+ )
1377
+ else:
1378
+ raise ValueError(f"Unknown inference mode: {inference_mode}. Must be 'stage2_only' or 'stage1+stage2'")
1379
+
1380
+ return final_pred
1381
+
1382
+
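+ # The block below is an illustrative, standalone sketch (not called anywhere in the
+ # pipeline) of the ROI expansion used in the 'stage1+stage2' branch above: scale the
+ # Stage 1 bounding box by m, enforce a minimum half-extent derived from the Stage 2
+ # crop size and target spacing, and clamp to the volume bounds. The helper name and
+ # its uniform (z, y, x) argument ordering are assumptions made for illustration only.
+ def _sketch_expand_roi(bbox_min, bbox_max, shape_zyx, spacing_zyx,
+                        crop_size=(192, 192, 192), target_spacing=(1.0, 1.0, 1.0), m=1.1):
+     lo, hi = [], []
+     for axis in range(3):
+         center = (bbox_min[axis] + bbox_max[axis]) / 2
+         half = (bbox_max[axis] - bbox_min[axis] + 1) * m / 2
+         # Minimum half-extent so the ROI covers at least one Stage 2 crop at target spacing.
+         min_half = (crop_size[axis] / 2) * target_spacing[axis] / spacing_zyx[axis]
+         half = max(half, min_half - 1)
+         lo.append(max(0, int(center - half)))
+         hi.append(min(shape_zyx[axis] - 1, int(center + half)))
+     return lo, hi
+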
1383
+ def run_inference(
1384
+ image_path,
1385
+ output_path,
1386
+ modality='CT',
1387
+ texts=None,
1388
+ label_values=None,
1389
+ inference_mode='stage2_only',
1390
+ device="cuda:0",
1391
+ checkpoints_path="./checkpoints",
1392
+ window_settings=None,
1393
+ window_type='soft_tissue',
1394
+ normalization_settings=None,
1395
+ window_type_mapping=None,
1396
+ verbose=True
1397
+ ):
1398
+ """
1399
+ Run Medal-S inference on a raw NIfTI image.
1400
+
1401
+ Supports multi-window inference for CT images: if multiple window types are specified
1402
+ (e.g., soft_tissue, bone, lung), each window type will be processed separately with
1403
+ its corresponding window settings, and results will be merged.
1404
+
1405
+ Args:
1406
+ image_path: Path to input NIfTI image
1407
+ output_path: Path to save output segmentation (will be modified with mode suffix)
1408
+ modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy')
1409
+ texts: List of text prompts (one per class)
1410
+ label_values: List of label values (one per class)
1411
+ inference_mode: Inference mode ('stage2_only' or 'stage1+stage2')
1412
+ device: Device to use ('cuda:0' or 'cpu')
1413
+ checkpoints_path: Path to model checkpoints
1414
+ window_settings: Dictionary containing window settings for different window types (CT only).
1415
+ Format: {'soft_tissue': {'window_level': 40, 'window_width': 400}, ...}
1416
+ window_type: Type of window to use ('soft_tissue', 'bone', 'lung'). Default: 'soft_tissue' (CT only)
1417
+ Ignored if window_type_mapping indicates multiple window types
1418
+ normalization_settings: Dictionary containing normalization settings for non-CT modalities.
1419
+ Format: {'percentile_lower': 0.5, 'percentile_upper': 99.5, 'preserve_zero': True}
1420
+ window_type_mapping: Dictionary mapping each text to its window type.
1421
+ Format: {'text1': 'soft_tissue', 'text2': 'bone', ...}
1422
+ If provided and contains multiple window types, will perform separate inference for each
1423
+ verbose: Whether to print detailed information (default: True)
1424
+
1425
+ Returns:
1426
+ pred_array: Segmentation array (d, h, w)
1427
+ inference_time: Total inference time in seconds
1428
+ """
1429
+ if texts is None:
1430
+ texts = []
1431
+ if label_values is None:
1432
+ label_values = []
1433
+
1434
+ if len(texts) != len(label_values):
1435
+ raise ValueError("Number of text prompts must match number of label values")
1436
+
1437
+ # Add mode suffix to output filename
1438
+ if inference_mode == 'stage1+stage2':
1439
+ suffix = '_stage1+stage2'
1440
+ elif inference_mode == 'stage2_only':
1441
+ suffix = '_stage2_only'
1442
+ else:
1443
+ suffix = f'_{inference_mode}'
1444
+
1445
+ # Modify output path to include suffix
1446
+ base_path, ext = os.path.splitext(output_path)
1447
+ if ext == '.gz': # Handle .nii.gz
1448
+ base_path, nii_ext = os.path.splitext(base_path)
1449
+ output_path = f"{base_path}{suffix}{nii_ext}{ext}"
1450
+ else:
1451
+ output_path = f"{base_path}{suffix}{ext}"
1452
+
1453
+ if verbose:
1454
+ print(f"Output will be saved to: {output_path}")
1455
+
1456
+ # Start timing
1457
+ start_time = time.time()
1458
+
1459
+ # Load image
1460
+ if verbose:
1461
+ print(f"Loading image: {image_path}")
1462
+ image_data, spacing_xyz, metadata = load_nifti_image(image_path)
1463
+ if verbose:
1464
+ print(f"Image shape: {image_data.shape}")
1465
+ print(f"Original spacing (x, y, z): {spacing_xyz}")
1466
+
1467
+ # Determine inference strategy based on modality and window types
1468
+ if modality.upper() == 'CT':
1469
+ # CT modality: check for multiple window types
1470
+ if window_type_mapping is not None:
1471
+ window_types = list(set(window_type_mapping.values()))
1472
+ if len(window_types) > 1:
1473
+ # Multiple window types: perform separate inference for each window type
1474
+ if verbose:
1475
+ print(f"\n{'='*60}")
1476
+ print(f"CT with {len(window_types)} window types detected: {window_types}")
1477
+ print("Performing separate inference for each window type...")
1478
+ print(f"{'='*60}\n")
1479
+
1480
+ all_predictions = []
1481
+
1482
+ for wt in window_types:
1483
+ if verbose:
1484
+ print(f"\n{'='*60}")
1485
+ print(f"Processing {wt} window type...")
1486
+ print(f"{'='*60}\n")
1487
+
1488
+ # Filter texts and label_values for this window type
1489
+ wt_texts = [text for text in texts if window_type_mapping.get(text) == wt]
1490
+ wt_indices = [i for i, text in enumerate(texts) if window_type_mapping.get(text) == wt]
1491
+ wt_label_values = [label_values[i] for i in wt_indices]
1492
+
1493
+ if len(wt_texts) == 0:
1494
+ if verbose:
1495
+ print(f"No classes for {wt} window type, skipping...")
1496
+ continue
1497
+
1498
+ if verbose:
1499
+ print(f"Classes for {wt} window: {len(wt_texts)}")
1500
+ print(f" Texts: {wt_texts}")
1501
+ print(f" Labels: {wt_label_values}")
1502
+
1503
+ # Run inference for this window type with its specific window settings
1504
+ wt_pred = run_inference_single_window(
1505
+ image_data=image_data,
1506
+ spacing_xyz=spacing_xyz,
1507
+ metadata=metadata,
1508
+ modality=modality,
1509
+ texts=wt_texts,
1510
+ label_values=wt_label_values,
1511
+ inference_mode=inference_mode,
1512
+ device=device,
1513
+ checkpoints_path=checkpoints_path,
1514
+ window_settings=window_settings,
1515
+ window_type=wt, # Use the specific window type
1516
+ normalization_settings=normalization_settings,
1517
+ verbose=verbose
1518
+ )
1519
+
1520
+ all_predictions.append((wt_pred, wt_label_values))
1521
+
1522
+ # Merge predictions: use maximum label value when overlapping
1523
+ if verbose:
1524
+ print(f"\n{'='*60}")
1525
+ print("Merging predictions from all window types...")
1526
+ print(f"{'='*60}\n")
1527
+
1528
+ final_pred = np.zeros_like(all_predictions[0][0], dtype=np.int16)
1529
+ for wt_pred, wt_labels in all_predictions:
1530
+ # For each label in this window type's prediction
1531
+ for label_val in wt_labels:
1532
+ label_int = int(label_val)
1533
+ mask = (wt_pred == label_int)
1534
+ # Only update if current prediction is background (0) or smaller label
1535
+ final_pred[mask] = np.maximum(final_pred[mask], label_int)
1536
+
1537
+ if verbose:
1538
+ print("Merging completed.")
1539
+ else:
1540
+ # Single window type: use the specific window type
1541
+ if len(window_types) == 1:
1542
+ window_type = window_types[0]
1543
+ if verbose:
1544
+ print(f"CT with single window type: {window_type}")
1545
+
1546
+ final_pred = run_inference_single_window(
1547
+ image_data=image_data,
1548
+ spacing_xyz=spacing_xyz,
1549
+ metadata=metadata,
1550
+ modality=modality,
1551
+ texts=texts,
1552
+ label_values=label_values,
1553
+ inference_mode=inference_mode,
1554
+ device=device,
1555
+ checkpoints_path=checkpoints_path,
1556
+ window_settings=window_settings,
1557
+ window_type=window_type, # Use the determined window type
1558
+ normalization_settings=normalization_settings,
1559
+ verbose=verbose
1560
+ )
1561
+ else:
1562
+ # No window_type_mapping: use default window_type
1563
+ if verbose:
1564
+ print(f"CT without window_type_mapping, using window type: {window_type}")
1565
+ final_pred = run_inference_single_window(
1566
+ image_data=image_data,
1567
+ spacing_xyz=spacing_xyz,
1568
+ metadata=metadata,
1569
+ modality=modality,
1570
+ texts=texts,
1571
+ label_values=label_values,
1572
+ inference_mode=inference_mode,
1573
+ device=device,
1574
+ checkpoints_path=checkpoints_path,
1575
+ window_settings=window_settings,
1576
+ window_type=window_type,
1577
+ normalization_settings=normalization_settings,
1578
+ verbose=verbose
1579
+ )
1580
+ else:
1581
+ # Non-CT modality: use normalization_settings (other normalization)
1582
+ if verbose:
1583
+ print(f"Non-CT modality ({modality}): using normalization_settings")
1584
+ final_pred = run_inference_single_window(
1585
+ image_data=image_data,
1586
+ spacing_xyz=spacing_xyz,
1587
+ metadata=metadata,
1588
+ modality=modality,
1589
+ texts=texts,
1590
+ label_values=label_values,
1591
+ inference_mode=inference_mode,
1592
+ device=device,
1593
+ checkpoints_path=checkpoints_path,
1594
+ window_settings=window_settings, # Not used for non-CT
1595
+ window_type=window_type, # Not used for non-CT
1596
+ normalization_settings=normalization_settings, # Used for non-CT
1597
+ verbose=verbose
1598
+ )
1599
+
1600
+ # End timing
1601
+ end_time = time.time()
1602
+ inference_time = end_time - start_time
1603
+
1604
+ if verbose:
1605
+ print(f"\n{'='*60}")
1606
+ print(f"Inference Mode: {inference_mode}")
1607
+ print(f"Total Inference Time: {inference_time:.2f} seconds ({inference_time/60:.2f} minutes)")
1608
+ print(f"{'='*60}\n")
1609
+
1610
+ # Save result
1611
+ if verbose:
1612
+ print(f"Saving segmentation to: {output_path}")
1613
+ seg_sitk = sitk.GetImageFromArray(final_pred.astype(np.int16))
1614
+ seg_sitk.SetSpacing(metadata['spacing_xyz'])
1615
+ seg_sitk.SetOrigin(metadata['origin'])
1616
+ seg_sitk.SetDirection(metadata['direction'])
1617
+ sitk.WriteImage(seg_sitk, output_path)
1618
+ if verbose:
1619
+ print(f"Successfully saved segmentation to: {output_path}")
1620
+
1621
+ return final_pred, inference_time
1622
+
1623
+
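+ # Illustrative numpy sketch (not part of the pipeline) of the merge rule used in
+ # run_inference above: when predictions from different CT window types overlap, the
+ # voxel keeps the larger label value. The helper name is hypothetical.
+ def _sketch_merge_by_max(predictions):
+     import numpy as np
+     merged = np.zeros_like(predictions[0][0], dtype=np.int16)
+     for pred, labels in predictions:
+         for label in labels:
+             label = int(label)
+             mask = pred == label
+             merged[mask] = np.maximum(merged[mask], label)
+     return merged
+
+ # Example: merging [1, 1, 0] (soft tissue) with [0, 2, 2] (bone) yields [1, 2, 2];
+ # the overlapping middle voxel takes the larger label, 2.
+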
1624
+ def load_config_from_json(config_path):
1625
+ """
1626
+ Load configuration from JSON file.
1627
+
1628
+ Supports two formats:
1629
+ 1. Legacy format: single 'texts' array
1630
+ 2. New format: separate arrays for 'texts_soft_tissue', 'texts_bone', 'texts_lung'
1631
+
1632
+ If 'labels' field is missing or empty, automatically generates consecutive
1633
+ integer labels starting from 1 (i.e., [1, 2, 3, ..., n] where n is the
1634
+ number of texts).
1635
+
1636
+ Args:
1637
+ config_path: Path to JSON configuration file
1638
+
1639
+ Returns:
1640
+ config: Dictionary containing configuration parameters with processed labels
1641
+
1642
+ Example:
1643
+ # Legacy format:
1644
+ {"texts": ["Aorta", "Liver"], "labels": [1, 2]}
1645
+
1646
+ # New format with window types:
1647
+ {
1648
+ "texts_soft_tissue": ["Aorta", "Liver"],
1649
+ "texts_bone": ["Vertebrae C1"],
1650
+ "texts_lung": ["Left lung"],
1651
+ "window_settings": {
1652
+ "soft_tissue": {"window_level": 40, "window_width": 400},
1653
+ "bone": {"window_level": 400, "window_width": 1500},
1654
+ "lung": {"window_level": -600, "window_width": 1500}
1655
+ }
1656
+ }
1657
+ """
1658
+ with open(config_path, 'r', encoding='utf-8') as f:
1659
+ config = json.load(f)
1660
+
1661
+ # Check if using new format (separate window types)
1662
+ has_window_types = any(key in config for key in ['texts_soft_tissue', 'texts_bone', 'texts_lung'])
1663
+
1664
+ if has_window_types:
1665
+ # New format: combine all texts from different window types
1666
+ texts_soft_tissue = config.get('texts_soft_tissue', [])
1667
+ texts_bone = config.get('texts_bone', [])
1668
+ texts_lung = config.get('texts_lung', [])
1669
+
1670
+ # Combine all texts in order: soft_tissue, bone, lung
1671
+ texts = texts_soft_tissue + texts_bone + texts_lung
1672
+
1673
+ # Store window type mapping for each text
1674
+ window_type_mapping = {}
1675
+ for text in texts_soft_tissue:
1676
+ window_type_mapping[text] = 'soft_tissue'
1677
+ for text in texts_bone:
1678
+ window_type_mapping[text] = 'bone'
1679
+ for text in texts_lung:
1680
+ window_type_mapping[text] = 'lung'
1681
+
1682
+ config['texts'] = texts
1683
+ config['window_type_mapping'] = window_type_mapping
1684
+ else:
1685
+ # Legacy format: single texts array
1686
+ texts = config.get('texts', [])
1687
+ # Default all texts to soft_tissue window type for backward compatibility
1688
+ window_type_mapping = {text: 'soft_tissue' for text in texts}
1689
+ config['window_type_mapping'] = window_type_mapping
1690
+
1691
+ # Process labels: auto-generate if missing or empty
1692
+ texts = config.get('texts', [])
1693
+ labels = config.get('labels', None)
1694
+
1695
+ if labels is None or len(labels) == 0:
1696
+ # Auto-generate consecutive labels starting from 1
1697
+ labels = list(range(1, len(texts) + 1))
1698
+ print(f" Auto-generated consecutive labels: {labels}")
1699
+ else:
1700
+ # Convert labels to integers (handle both string and integer inputs)
1701
+ labels = [int(label) for label in labels]
1702
+
1703
+ # Validate that number of labels matches number of texts
1704
+ if len(labels) != len(texts):
1705
+ raise ValueError(
1706
+ f"Number of labels ({len(labels)}) must match number of texts ({len(texts)}). "
1707
+ f"Texts: {len(texts)}, Labels: {len(labels)}"
1708
+ )
1709
+
1710
+ config['labels'] = labels
1711
+ return config
1712
+
1713
+
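+ # Illustrative round-trip sketch (never called by the pipeline) of the new-format
+ # config handling in load_config_from_json, including auto-generated labels. The
+ # class names and temporary file used here are placeholders, not repository data.
+ def _example_config_roundtrip():
+     import json, tempfile
+     example = {
+         "modality": "CT",
+         "texts_soft_tissue": ["Liver"],
+         "texts_bone": ["Vertebrae C1"],
+     }
+     with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
+         json.dump(example, f)
+         path = f.name
+     cfg = load_config_from_json(path)
+     assert cfg["texts"] == ["Liver", "Vertebrae C1"]
+     assert cfg["labels"] == [1, 2]  # auto-generated: no 'labels' field in the file
+     assert cfg["window_type_mapping"] == {"Liver": "soft_tissue", "Vertebrae C1": "bone"}
+     return cfg
+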
1714
+ def main():
1715
+ """
1716
+ Main entry point for the inference script.
1717
+
1718
+ Parses command-line arguments and runs inference with the specified
1719
+ configuration.
1720
+ """
1721
+ parser = argparse.ArgumentParser(
1722
+ description="Medal-S inference for raw NIfTI images",
1723
+ formatter_class=argparse.RawDescriptionHelpFormatter,
1724
+ epilog="""
1725
+ Examples:
1726
+ # Using JSON configuration file:
1727
+ python inference_medals.py --input image.nii.gz --output result.nii.gz \\
1728
+ --config config.json --mode stage2_only
1729
+
1730
+ # Using command-line arguments:
1731
+ python inference_medals.py --input image.nii.gz --output result.nii.gz \\
1732
+ --modality CT --texts "Aorta in CT" --labels 1 --mode stage1+stage2
1733
+ """
1734
+ )
1735
+ parser.add_argument(
1736
+ "--input", "-i",
1737
+ type=str,
1738
+ required=True,
1739
+ help="Path to input NIfTI image"
1740
+ )
1741
+ parser.add_argument(
1742
+ "--output", "-o",
1743
+ type=str,
1744
+ required=True,
1745
+ help="Path to save output segmentation (suffix will be added automatically based on inference mode)"
1746
+ )
1747
+ parser.add_argument(
1748
+ "--config", "-c",
1749
+ type=str,
1750
+ default=None,
1751
+ help="Path to JSON configuration file (if provided, will override --texts, --labels, --modality)"
1752
+ )
1753
+ parser.add_argument(
1754
+ "--modality", "-m",
1755
+ type=str,
1756
+ default="CT",
1757
+ choices=['CT', 'MRI', 'US', 'PET', 'microscopy'],
1758
+ help="Imaging modality (default: CT, ignored if --config is provided)"
1759
+ )
1760
+ parser.add_argument(
1761
+ "--texts",
1762
+ type=str,
1763
+ nargs='+',
1764
+ default=None,
1765
+ help="Text prompts (one per class, ignored if --config is provided)"
1766
+ )
1767
+ parser.add_argument(
1768
+ "--labels",
1769
+ type=str,
1770
+ nargs='+',
1771
+ default=None,
1772
+ help="Label values (one per class, must match texts, ignored if --config is provided)"
1773
+ )
1774
+ parser.add_argument(
1775
+ "--mode",
1776
+ type=str,
1777
+ default="stage2_only",
1778
+ choices=['stage2_only', 'stage1+stage2'],
1779
+ help="Inference mode: 'stage2_only' (default) or 'stage1+stage2'"
1780
+ )
1781
+ parser.add_argument(
1782
+ "--device",
1783
+ type=str,
1784
+ default="cuda:0",
1785
+ help="Device to use (default: cuda:0)"
1786
+ )
1787
+ parser.add_argument(
1788
+ "--checkpoints",
1789
+ type=str,
1790
+ default="./checkpoints",
1791
+ help="Path to model checkpoints (default: ./checkpoints)"
1792
+ )
1793
+ parser.add_argument(
1794
+ "--verbose", "-v",
1795
+ action='store_true',
1796
+ default=False,
1797
+ help="Print detailed information during inference (default: False)"
1798
+ )
1799
+
1800
+ args = parser.parse_args()
1801
+ verbose = args.verbose
1802
+
1803
+ # Load configuration from JSON file if provided
1804
+ window_settings = None
1805
+ window_type = 'soft_tissue'
1806
+ normalization_settings = None
1807
+ window_type_mapping = None
1808
+
1809
+ if args.config:
1810
+ if not os.path.exists(args.config):
1811
+ raise FileNotFoundError(f"Configuration file not found: {args.config}")
1812
+ config = load_config_from_json(args.config)
1813
+ texts = config.get('texts', [])
1814
+ labels = config.get('labels', [])
1815
+ modality = config.get('modality', 'CT')
1816
+ window_settings = config.get('window_settings')
1817
+ normalization_settings = config.get('normalization_settings')
1818
+ window_type_mapping = config.get('window_type_mapping')
1819
+
1820
+ # Determine default window type based on texts (for CT only, used as fallback)
1821
+ if modality.upper() == 'CT':
1822
+ if window_type_mapping:
1823
+ window_types = list(set(window_type_mapping.values()))
1824
+ if len(window_types) == 1:
1825
+ window_type = window_types[0]
1826
+ else:
1827
+ # Default to soft_tissue if mixed types (will be handled by multi-window inference)
1828
+ window_type = 'soft_tissue'
1829
+
1830
+ # Convert labels to strings for compatibility with run_segmentation
1831
+ # (run_segmentation expects string labels)
1832
+ label_values = [str(label) for label in labels]
1833
+
1834
+ if verbose:
1835
+ print(f"Loaded configuration from: {args.config}")
1836
+ print(f" Modality: {modality}")
1837
+ print(f" Number of classes: {len(texts)}")
1838
+ print(f" Labels: {labels}")
1839
+ if modality.upper() == 'CT' and window_settings:
1840
+ print(f" Window settings available for: {list(window_settings.keys())}")
1841
+ if window_type_mapping:
1842
+ window_types = list(set(window_type_mapping.values()))
1843
+ if len(window_types) > 1:
1844
+ print(f" Multiple window types detected: {window_types}")
1845
+ print(f" Will perform separate inference for each window type")
1846
+ else:
1847
+ print(f" Using window type: {window_type}")
1848
+ else:
1849
+ print(f" Using window type: {window_type}")
1850
+ elif normalization_settings:
1851
+ print(f" Normalization settings: {normalization_settings}")
1852
+ else:
1853
+ # Use command line arguments
1854
+ if args.texts is None or args.labels is None:
1855
+ raise ValueError("Either --config or both --texts and --labels must be provided")
1856
+ texts = args.texts
1857
+ label_values = args.labels
1858
+ modality = args.modality
1859
+
1860
+ # Create output directory if needed
1861
+ output_dir = os.path.dirname(args.output)
1862
+ if output_dir and not os.path.exists(output_dir):
1863
+ os.makedirs(output_dir, exist_ok=True)
1864
+
1865
+ # Run inference
1866
+ run_inference(
1867
+ image_path=args.input,
1868
+ output_path=args.output,
1869
+ modality=modality,
1870
+ texts=texts,
1871
+ label_values=label_values,
1872
+ inference_mode=args.mode,
1873
+ device=args.device,
1874
+ checkpoints_path=args.checkpoints,
1875
+ window_settings=window_settings,
1876
+ window_type=window_type,
1877
+ normalization_settings=normalization_settings,
1878
+ window_type_mapping=window_type_mapping,
1879
+ verbose=verbose
1880
+ )
1881
+
1882
+
1883
+ if __name__ == '__main__':
1884
+ main()
1885
+
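+ # Illustrative programmatic usage (paths, prompts, and label values below are
+ # placeholders, not repository data): the same entry point the CLI wraps can be
+ # called directly from Python.
+ #
+ #     from inference_medals_nifti import run_inference
+ #
+ #     pred, seconds = run_inference(
+ #         image_path="case_0001.nii.gz",
+ #         output_path="case_0001_seg.nii.gz",
+ #         modality="CT",
+ #         texts=["Liver", "Spleen"],
+ #         label_values=["1", "2"],
+ #         inference_mode="stage2_only",
+ #     )
+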
model/SwinUNETR.py ADDED
@@ -0,0 +1,1116 @@
1
+ from typing import Sequence, Tuple, Type, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint as checkpoint
8
+ from torch.nn import LayerNorm
9
+
10
+ from monai.networks.blocks import MLPBlock as Mlp
11
+ from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
12
+ from monai.networks.layers import DropPath, trunc_normal_
13
+ from monai.utils import ensure_tuple_rep, optional_import
14
+
15
+ rearrange, _ = optional_import("einops", name="rearrange")
16
+
17
+
18
+ class SwinUNETR_Enc(nn.Module):
19
+ """
20
+ Swin UNETR based on: "Hatamizadeh et al.,
21
+ Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
22
+ <https://arxiv.org/abs/2201.01266>"
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ img_size: Union[Sequence[int], int],
28
+ in_channels: int,
29
+ depths: Sequence[int] = (2, 2, 2, 2),
30
+ num_heads: Sequence[int] = (3, 6, 12, 24),
31
+ feature_size: int = 24,
32
+ norm_name: Union[Tuple, str] = "instance",
33
+ drop_rate: float = 0.0,
34
+ attn_drop_rate: float = 0.0,
35
+ dropout_path_rate: float = 0.0,
36
+ normalize: bool = True,
37
+ use_checkpoint: bool = False,
38
+ spatial_dims: int = 3,
39
+ return_skips: bool = True,
40
+ ) -> None:
41
+ """
42
+ Args:
43
+ img_size: dimension of input image.
44
+ in_channels: dimension of input channels.
45
+ out_channels: dimension of output channels.
47
+ depths: number of layers in each stage.
48
+ num_heads: number of attention heads.
49
+ norm_name: feature normalization type and arguments.
50
+ drop_rate: dropout rate.
51
+ attn_drop_rate: attention dropout rate.
52
+ dropout_path_rate: drop path rate.
53
+ normalize: normalize output intermediate features in each stage.
54
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
55
+ spatial_dims: number of spatial dims.
56
+ """
57
+
58
+ super().__init__()
59
+
60
+ self.return_skips = return_skips
61
+
62
+ img_size = ensure_tuple_rep(img_size, spatial_dims)
63
+ patch_size = ensure_tuple_rep(2, spatial_dims)
64
+ window_size = ensure_tuple_rep(7, spatial_dims)
65
+
66
+ if not (spatial_dims == 2 or spatial_dims == 3):
67
+ raise ValueError("spatial dimension should be 2 or 3.")
68
+
69
+ for m, p in zip(img_size, patch_size):
70
+ for i in range(5):
71
+ if m % np.power(p, i + 1) != 0:
72
+ raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
73
+
74
+ if not (0 <= drop_rate <= 1):
75
+ raise ValueError("dropout rate should be between 0 and 1.")
76
+
77
+ if not (0 <= attn_drop_rate <= 1):
78
+ raise ValueError("attention dropout rate should be between 0 and 1.")
79
+
80
+ if not (0 <= dropout_path_rate <= 1):
81
+ raise ValueError("drop path rate should be between 0 and 1.")
82
+
83
+ if feature_size % 12 != 0:
84
+ raise ValueError("feature_size should be divisible by 12.")
85
+
86
+ self.normalize = normalize
87
+
88
+ self.swinViT = SwinTransformer(
89
+ in_chans=in_channels,
90
+ embed_dim=feature_size,
91
+ window_size=window_size,
92
+ patch_size=patch_size,
93
+ depths=depths,
94
+ num_heads=num_heads,
95
+ mlp_ratio=4.0,
96
+ qkv_bias=True,
97
+ drop_rate=drop_rate,
98
+ attn_drop_rate=attn_drop_rate,
99
+ drop_path_rate=dropout_path_rate,
100
+ norm_layer=nn.LayerNorm,
101
+ use_checkpoint=use_checkpoint,
102
+ spatial_dims=spatial_dims,
103
+ )
104
+
105
+ self.encoder1 = UnetrBasicBlock( # 2 conv layers
106
+ spatial_dims=spatial_dims,
107
+ in_channels=in_channels,
108
+ out_channels=feature_size,
109
+ kernel_size=3,
110
+ stride=1,
111
+ norm_name=norm_name,
112
+ res_block=True,
113
+ )
114
+
115
+ self.encoder2 = UnetrBasicBlock(
116
+ spatial_dims=spatial_dims,
117
+ in_channels=feature_size,
118
+ out_channels=feature_size,
119
+ kernel_size=3,
120
+ stride=1,
121
+ norm_name=norm_name,
122
+ res_block=True,
123
+ )
124
+
125
+ self.encoder3 = UnetrBasicBlock(
126
+ spatial_dims=spatial_dims,
127
+ in_channels=2 * feature_size,
128
+ out_channels=2 * feature_size,
129
+ kernel_size=3,
130
+ stride=1,
131
+ norm_name=norm_name,
132
+ res_block=True,
133
+ )
134
+
135
+ self.encoder4 = UnetrBasicBlock(
136
+ spatial_dims=spatial_dims,
137
+ in_channels=4 * feature_size,
138
+ out_channels=4 * feature_size,
139
+ kernel_size=3,
140
+ stride=1,
141
+ norm_name=norm_name,
142
+ res_block=True,
143
+ )
144
+
145
+ self.encoder5 = UnetrBasicBlock(
146
+ spatial_dims=spatial_dims,
147
+ in_channels=8 * feature_size,
148
+ out_channels=8 * feature_size,
149
+ kernel_size=3,
150
+ stride=1,
151
+ norm_name=norm_name,
152
+ res_block=True,
153
+ )
154
+
155
+ self.encoder6 = UnetrBasicBlock(
156
+ spatial_dims=spatial_dims,
157
+ in_channels=16 * feature_size,
158
+ out_channels=16 * feature_size,
159
+ kernel_size=3,
160
+ stride=1,
161
+ norm_name=norm_name,
162
+ res_block=True,
163
+ )
164
+
165
+ def load_from(self, weights):
166
+
167
+ with torch.no_grad():
168
+ self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"])
169
+ self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"])
170
+ for bname, block in self.swinViT.layers1[0].blocks.named_children():
171
+ block.load_from(weights, n_block=bname, layer="layers1")
172
+ self.swinViT.layers1[0].downsample.reduction.weight.copy_(
173
+ weights["state_dict"]["module.layers1.0.downsample.reduction.weight"]
174
+ )
175
+ self.swinViT.layers1[0].downsample.norm.weight.copy_(
176
+ weights["state_dict"]["module.layers1.0.downsample.norm.weight"]
177
+ )
178
+ self.swinViT.layers1[0].downsample.norm.bias.copy_(
179
+ weights["state_dict"]["module.layers1.0.downsample.norm.bias"]
180
+ )
181
+ for bname, block in self.swinViT.layers2[0].blocks.named_children():
182
+ block.load_from(weights, n_block=bname, layer="layers2")
183
+ self.swinViT.layers2[0].downsample.reduction.weight.copy_(
184
+ weights["state_dict"]["module.layers2.0.downsample.reduction.weight"]
185
+ )
186
+ self.swinViT.layers2[0].downsample.norm.weight.copy_(
187
+ weights["state_dict"]["module.layers2.0.downsample.norm.weight"]
188
+ )
189
+ self.swinViT.layers2[0].downsample.norm.bias.copy_(
190
+ weights["state_dict"]["module.layers2.0.downsample.norm.bias"]
191
+ )
192
+ for bname, block in self.swinViT.layers3[0].blocks.named_children():
193
+ block.load_from(weights, n_block=bname, layer="layers3")
194
+ self.swinViT.layers3[0].downsample.reduction.weight.copy_(
195
+ weights["state_dict"]["module.layers3.0.downsample.reduction.weight"]
196
+ )
197
+ self.swinViT.layers3[0].downsample.norm.weight.copy_(
198
+ weights["state_dict"]["module.layers3.0.downsample.norm.weight"]
199
+ )
200
+ self.swinViT.layers3[0].downsample.norm.bias.copy_(
201
+ weights["state_dict"]["module.layers3.0.downsample.norm.bias"]
202
+ )
203
+ for bname, block in self.swinViT.layers4[0].blocks.named_children():
204
+ block.load_from(weights, n_block=bname, layer="layers4")
205
+ self.swinViT.layers4[0].downsample.reduction.weight.copy_(
206
+ weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
207
+ )
208
+ self.swinViT.layers4[0].downsample.norm.weight.copy_(
209
+ weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
210
+ )
211
+ self.swinViT.layers4[0].downsample.norm.bias.copy_(
212
+ weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
213
+ )
214
+
215
+ def forward(self, x_in):
216
+ # print(x_in.shape, task_id.shape)
217
+ hidden_states_out = self.swinViT(x_in, self.normalize)
218
+
219
+ enc0 = self.encoder1(x_in)
220
+ enc1 = self.encoder2(hidden_states_out[0])
221
+ enc2 = self.encoder3(hidden_states_out[1])
222
+ enc3 = self.encoder4(hidden_states_out[2])
223
+ enc4 = self.encoder5(hidden_states_out[3])
224
+ dec4 = self.encoder6(hidden_states_out[4])
225
+ # print(x_in.shape, enc0.shape, enc1.shape, enc2.shape, enc3.shape, dec4.shape)
226
+ # torch.Size([6, 1, 64, 64, 64]) torch.Size([6, 48, 64, 64, 64]) torch.Size([6, 48, 32, 32, 32])
227
+ # torch.Size([6, 96, 16, 16, 16]) torch.Size([6, 192, 8,8, 8]) torch.Size([6, 768, 2, 2, 2])
228
+
229
+ if self.return_skips:
230
+ return [enc0, enc1, enc2, enc3, enc4, dec4]
231
+ else:
232
+ return [dec4]
233
+
234
+ class SwinUNETR(nn.Module):
235
+ """
236
+ Swin UNETR based on: "Hatamizadeh et al.,
237
+ Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
238
+ <https://arxiv.org/abs/2201.01266>"
239
+ """
240
+
241
+ def __init__(
242
+ self,
243
+ img_size: Union[Sequence[int], int],
244
+ in_channels: int,
245
+ depths: Sequence[int] = (2, 2, 2, 2),
246
+ num_heads: Sequence[int] = (3, 6, 12, 24),
247
+ feature_size: int = 24,
248
+ norm_name: Union[Tuple, str] = "instance",
249
+ drop_rate: float = 0.0,
250
+ attn_drop_rate: float = 0.0,
251
+ dropout_path_rate: float = 0.0,
252
+ normalize: bool = True,
253
+ use_checkpoint: bool = False,
254
+ spatial_dims: int = 3,
255
+ encoding: Union[Tuple, str] = 'rand_embedding', ## rand_embedding or word_embedding
256
+ deep_supervision: bool = True,
257
+ return_skips: bool = True,
258
+ ) -> None:
259
+ """
260
+ Args:
261
+ img_size: dimension of input image.
262
+ in_channels: dimension of input channels.
264
+ feature_size: dimension of network feature size.
265
+ depths: number of layers in each stage.
266
+ num_heads: number of attention heads.
267
+ norm_name: feature normalization type and arguments.
268
+ drop_rate: dropout rate.
269
+ attn_drop_rate: attention dropout rate.
270
+ dropout_path_rate: drop path rate.
271
+ normalize: normalize output intermediate features in each stage.
272
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
273
+ spatial_dims: number of spatial dims.
274
+ Examples::
275
+ # for 3D single channel input with size (96,96,96) and feature size of 48.
276
+ >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, feature_size=48)
277
+ # for 3D 4-channel input with size (128,128,128) and (2,4,2,2) layers in each stage.
278
+ >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, depths=(2,4,2,2))
279
+ # for 2D 3-channel input with size (96,96) and gradient checkpointing.
280
+ >>> net = SwinUNETR(img_size=(96,96), in_channels=3, use_checkpoint=True, spatial_dims=2)
281
+ """
282
+
283
+ super().__init__()
284
+
285
+ self.deep_supervision = deep_supervision
286
+ self.return_skips = return_skips
287
+
288
+ self.encoding = encoding
289
+
290
+ img_size = ensure_tuple_rep(img_size, spatial_dims)
291
+ patch_size = ensure_tuple_rep(2, spatial_dims)
292
+ window_size = ensure_tuple_rep(7, spatial_dims)
293
+
294
+ if not (spatial_dims == 2 or spatial_dims == 3):
295
+ raise ValueError("spatial dimension should be 2 or 3.")
296
+
297
+ for m, p in zip(img_size, patch_size):
298
+ for i in range(5):
299
+ if m % np.power(p, i + 1) != 0:
300
+ raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
301
+
302
+ if not (0 <= drop_rate <= 1):
303
+ raise ValueError("dropout rate should be between 0 and 1.")
304
+
305
+ if not (0 <= attn_drop_rate <= 1):
306
+ raise ValueError("attention dropout rate should be between 0 and 1.")
307
+
308
+ if not (0 <= dropout_path_rate <= 1):
309
+ raise ValueError("drop path rate should be between 0 and 1.")
310
+
311
+ if feature_size % 12 != 0:
312
+ raise ValueError("feature_size should be divisible by 12.")
313
+
314
+ self.normalize = normalize
315
+
316
+ self.encoder = SwinUNETR_Enc(
317
+ img_size,
318
+ in_channels,
319
+ depths,
320
+ num_heads,
321
+ feature_size,
322
+ norm_name,
323
+ drop_rate,
324
+ attn_drop_rate,
325
+ dropout_path_rate,
326
+ normalize,
327
+ use_checkpoint,
328
+ spatial_dims,
329
+ return_skips=True
330
+ )
331
+
332
+ self.decoder5 = UnetrUpBlock( # a transpose conv layer and 2 conv layers
333
+ spatial_dims=spatial_dims,
334
+ in_channels=16 * feature_size,
335
+ out_channels=8 * feature_size,
336
+ kernel_size=3,
337
+ upsample_kernel_size=2,
338
+ norm_name=norm_name,
339
+ res_block=True,
340
+ )
341
+
342
+ self.decoder4 = UnetrUpBlock(
343
+ spatial_dims=spatial_dims,
344
+ in_channels=feature_size * 8,
345
+ out_channels=feature_size * 4,
346
+ kernel_size=3,
347
+ upsample_kernel_size=2,
348
+ norm_name=norm_name,
349
+ res_block=True,
350
+ )
351
+
352
+ self.decoder3 = UnetrUpBlock(
353
+ spatial_dims=spatial_dims,
354
+ in_channels=feature_size * 4,
355
+ out_channels=feature_size * 2,
356
+ kernel_size=3,
357
+ upsample_kernel_size=2,
358
+ norm_name=norm_name,
359
+ res_block=True,
360
+ )
361
+ self.decoder2 = UnetrUpBlock(
362
+ spatial_dims=spatial_dims,
363
+ in_channels=feature_size * 2,
364
+ out_channels=feature_size,
365
+ kernel_size=3,
366
+ upsample_kernel_size=2,
367
+ norm_name=norm_name,
368
+ res_block=True,
369
+ )
370
+
371
+ self.decoder1 = UnetrUpBlock(
372
+ spatial_dims=spatial_dims,
373
+ in_channels=feature_size,
374
+ out_channels=feature_size,
375
+ kernel_size=3,
376
+ upsample_kernel_size=2,
377
+ norm_name=norm_name,
378
+ res_block=True,
379
+ )
380
+
381
+ def forward(self, x_in):
382
+ enc0, enc1, enc2, enc3, enc4, dec4 = self.encoder(x_in)
383
+
384
+ dec3 = self.decoder5(dec4, enc4)
385
+ dec2 = self.decoder4(dec3, enc3)
386
+ dec1 = self.decoder3(dec2, enc2)
387
+ dec0 = self.decoder2(dec1, enc1)
388
+ out = self.decoder1(dec0, enc0)
389
+ # print(dec3.shape, dec2.shape, dec1.shape, dec0.shape, out.shape)
390
+ # torch.Size([6, 384, 4, 4, 4]) torch.Size([6, 192, 8, 8, 8]) torch.Size([6, 96, 16, 16, 16])
391
+ # torch.Size([6, 48, 32, 32, 32]) torch.Size([6, 48, 64, 64, 64])
392
+
393
+ if self.deep_supervision:
394
+ out_ls = [out, dec0, dec1, dec2, dec3]
395
+ else:
396
+ out_ls = [out]
397
+
398
+ if self.return_skips:
399
+ skips = [enc0, enc1, enc2, enc3, enc4, dec4]
400
+ else:
401
+ skips = [dec4]
402
+
403
+ return skips, out_ls
404
+
405
+
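+ # Minimal smoke-test sketch (an illustrative example, never called in this repo) for
+ # the modified SwinUNETR above, which returns encoder skips plus deep-supervision
+ # feature maps rather than a single logit map. Sizes below are examples only; call
+ # this only after the module is fully imported.
+ def _sketch_swinunetr_forward():
+     net = SwinUNETR(img_size=(64, 64, 64), in_channels=1, feature_size=24)
+     x = torch.zeros(1, 1, 64, 64, 64)
+     skips, outs = net(x)
+     # skips: 6 feature maps, from full resolution (feature_size channels) down to the
+     # bottleneck (16 * feature_size channels); outs[0] is the full-resolution decoder
+     # output with feature_size channels (no segmentation head is attached here).
+     return [s.shape for s in skips], [o.shape for o in outs]
+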
406
+ def window_partition(x, window_size):
407
+ """window partition operation based on: "Liu et al.,
408
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
409
+ <https://arxiv.org/abs/2103.14030>"
410
+ https://github.com/microsoft/Swin-Transformer
411
+ Args:
412
+ x: input tensor.
413
+ window_size: local window size.
414
+ """
415
+ x_shape = x.size()
416
+ if len(x_shape) == 5:
417
+ b, d, h, w, c = x_shape
418
+ x = x.view(
419
+ b,
420
+ d // window_size[0],
421
+ window_size[0],
422
+ h // window_size[1],
423
+ window_size[1],
424
+ w // window_size[2],
425
+ window_size[2],
426
+ c,
427
+ )
428
+ windows = (
429
+ x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
430
+ )
431
+ elif len(x_shape) == 4:
432
+ b, h, w, c = x.shape
433
+ x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
434
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
435
+ return windows
436
+
437
+
438
+ def window_reverse(windows, window_size, dims):
439
+ """window reverse operation based on: "Liu et al.,
440
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
441
+ <https://arxiv.org/abs/2103.14030>"
442
+ https://github.com/microsoft/Swin-Transformer
443
+ Args:
444
+ windows: windows tensor.
445
+ window_size: local window size.
446
+ dims: dimension values.
447
+ """
448
+ if len(dims) == 4:
449
+ b, d, h, w = dims
450
+ x = windows.view(
451
+ b,
452
+ d // window_size[0],
453
+ h // window_size[1],
454
+ w // window_size[2],
455
+ window_size[0],
456
+ window_size[1],
457
+ window_size[2],
458
+ -1,
459
+ )
460
+ x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
461
+
462
+ elif len(dims) == 3:
463
+ b, h, w = dims
464
+ x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1)
465
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
466
+ return x
467
+
468
+
469
+ def get_window_size(x_size, window_size, shift_size=None):
470
+ """Computing window size based on: "Liu et al.,
471
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
472
+ <https://arxiv.org/abs/2103.14030>"
473
+ https://github.com/microsoft/Swin-Transformer
474
+ Args:
475
+ x_size: input size.
476
+ window_size: local window size.
477
+ shift_size: window shifting size.
478
+ """
479
+
480
+ use_window_size = list(window_size)
481
+ if shift_size is not None:
482
+ use_shift_size = list(shift_size)
483
+ for i in range(len(x_size)):
484
+ if x_size[i] <= window_size[i]:
485
+ use_window_size[i] = x_size[i]
486
+ if shift_size is not None:
487
+ use_shift_size[i] = 0
488
+
489
+ if shift_size is None:
490
+ return tuple(use_window_size)
491
+ else:
492
+ return tuple(use_window_size), tuple(use_shift_size)
493
+
494
+
495
+ class WindowAttention(nn.Module):
496
+ """
497
+ Window based multi-head self attention module with relative position bias based on: "Liu et al.,
498
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
499
+ <https://arxiv.org/abs/2103.14030>"
500
+ https://github.com/microsoft/Swin-Transformer
501
+ """
502
+
503
+ def __init__(
504
+ self,
505
+ dim: int,
506
+ num_heads: int,
507
+ window_size: Sequence[int],
508
+ qkv_bias: bool = False,
509
+ attn_drop: float = 0.0,
510
+ proj_drop: float = 0.0,
511
+ ) -> None:
512
+ """
513
+ Args:
514
+ dim: number of feature channels.
515
+ num_heads: number of attention heads.
516
+ window_size: local window size.
517
+ qkv_bias: add a learnable bias to query, key, value.
518
+ attn_drop: attention dropout rate.
519
+ proj_drop: dropout rate of output.
520
+ """
521
+
522
+ super().__init__()
523
+ self.dim = dim
524
+ self.window_size = window_size
525
+ self.num_heads = num_heads
526
+ head_dim = dim // num_heads
527
+ self.scale = head_dim**-0.5
528
+ mesh_args = torch.meshgrid.__kwdefaults__
529
+
530
+ if len(self.window_size) == 3:
531
+ self.relative_position_bias_table = nn.Parameter(
532
+ torch.zeros(
533
+ (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
534
+ num_heads,
535
+ )
536
+ )
537
+ coords_d = torch.arange(self.window_size[0])
538
+ coords_h = torch.arange(self.window_size[1])
539
+ coords_w = torch.arange(self.window_size[2])
540
+ if mesh_args is not None:
541
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
542
+ else:
543
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
544
+ coords_flatten = torch.flatten(coords, 1)
545
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
546
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
547
+ relative_coords[:, :, 0] += self.window_size[0] - 1
548
+ relative_coords[:, :, 1] += self.window_size[1] - 1
549
+ relative_coords[:, :, 2] += self.window_size[2] - 1
550
+ relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
551
+ relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
552
+ elif len(self.window_size) == 2:
553
+ self.relative_position_bias_table = nn.Parameter(
554
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
555
+ )
556
+ coords_h = torch.arange(self.window_size[0])
557
+ coords_w = torch.arange(self.window_size[1])
558
+ if mesh_args is not None:
559
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
560
+ else:
561
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w))
562
+ coords_flatten = torch.flatten(coords, 1)
563
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
564
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
565
+ relative_coords[:, :, 0] += self.window_size[0] - 1
566
+ relative_coords[:, :, 1] += self.window_size[1] - 1
567
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
568
+
569
+ relative_position_index = relative_coords.sum(-1)
570
+ self.register_buffer("relative_position_index", relative_position_index)
571
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
572
+ self.attn_drop = nn.Dropout(attn_drop)
573
+ self.proj = nn.Linear(dim, dim)
574
+ self.proj_drop = nn.Dropout(proj_drop)
575
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
576
+ self.softmax = nn.Softmax(dim=-1)
577
+
578
+ def forward(self, x, mask):
579
+ b, n, c = x.shape
580
+ qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
581
+ q, k, v = qkv[0], qkv[1], qkv[2]
582
+ q = q * self.scale
583
+ attn = q @ k.transpose(-2, -1)
584
+ relative_position_bias = self.relative_position_bias_table[
585
+ self.relative_position_index.clone()[:n, :n].reshape(-1)
586
+ ].reshape(n, n, -1)
587
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
588
+ attn = attn + relative_position_bias.unsqueeze(0)
589
+ if mask is not None:
590
+ nw = mask.shape[0]
591
+ attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
592
+ attn = attn.view(-1, self.num_heads, n, n)
593
+ attn = self.softmax(attn)
594
+ else:
595
+ attn = self.softmax(attn)
596
+
597
+ attn = self.attn_drop(attn)
598
+ x = (attn @ v).transpose(1, 2).reshape(b, n, c)
599
+ x = self.proj(x)
600
+ x = self.proj_drop(x)
601
+ return x
602
+
603
+
604
+ class SwinTransformerBlock(nn.Module):
605
+ """
606
+ Swin Transformer block based on: "Liu et al.,
607
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
608
+ <https://arxiv.org/abs/2103.14030>"
609
+ https://github.com/microsoft/Swin-Transformer
610
+ """
611
+
612
+ def __init__(
613
+ self,
614
+ dim: int,
615
+ num_heads: int,
616
+ window_size: Sequence[int],
617
+ shift_size: Sequence[int],
618
+ mlp_ratio: float = 4.0,
619
+ qkv_bias: bool = True,
620
+ drop: float = 0.0,
621
+ attn_drop: float = 0.0,
622
+ drop_path: float = 0.0,
623
+ act_layer: str = "GELU",
624
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
625
+ use_checkpoint: bool = False,
626
+ ) -> None:
627
+ """
628
+ Args:
629
+ dim: number of feature channels.
630
+ num_heads: number of attention heads.
631
+ window_size: local window size.
632
+ shift_size: window shift size.
633
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
634
+ qkv_bias: add a learnable bias to query, key, value.
635
+ drop: dropout rate.
636
+ attn_drop: attention dropout rate.
637
+ drop_path: stochastic depth rate.
638
+ act_layer: activation layer.
639
+ norm_layer: normalization layer.
640
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
641
+ """
642
+
643
+ super().__init__()
644
+ self.dim = dim
645
+ self.num_heads = num_heads
646
+ self.window_size = window_size
647
+ self.shift_size = shift_size
648
+ self.mlp_ratio = mlp_ratio
649
+ self.use_checkpoint = use_checkpoint
650
+ self.norm1 = norm_layer(dim)
651
+ self.attn = WindowAttention(
652
+ dim,
653
+ window_size=self.window_size,
654
+ num_heads=num_heads,
655
+ qkv_bias=qkv_bias,
656
+ attn_drop=attn_drop,
657
+ proj_drop=drop,
658
+ )
659
+
660
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
661
+ self.norm2 = norm_layer(dim)
662
+ mlp_hidden_dim = int(dim * mlp_ratio)
663
+ self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
664
+
665
+ def forward_part1(self, x, mask_matrix):
666
+ x_shape = x.size()
667
+ x = self.norm1(x)
668
+ if len(x_shape) == 5:
669
+ b, d, h, w, c = x.shape
670
+ window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
671
+ pad_l = pad_t = pad_d0 = 0
672
+ pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
673
+ pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
674
+ pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
675
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
676
+ _, dp, hp, wp, _ = x.shape
677
+ dims = [b, dp, hp, wp]
678
+
679
+ elif len(x_shape) == 4:
680
+ b, h, w, c = x.shape
681
+ window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
682
+ pad_l = pad_t = 0
683
+ pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
684
+ pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
685
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
686
+ _, hp, wp, _ = x.shape
687
+ dims = [b, hp, wp]
688
+
689
+ if any(i > 0 for i in shift_size):
690
+ if len(x_shape) == 5:
691
+ shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
692
+ elif len(x_shape) == 4:
693
+ shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
694
+ attn_mask = mask_matrix
695
+ else:
696
+ shifted_x = x
697
+ attn_mask = None
698
+ x_windows = window_partition(shifted_x, window_size)
699
+ attn_windows = self.attn(x_windows, mask=attn_mask)
700
+ attn_windows = attn_windows.view(-1, *(window_size + (c,)))
701
+ shifted_x = window_reverse(attn_windows, window_size, dims)
702
+ if any(i > 0 for i in shift_size):
703
+ if len(x_shape) == 5:
704
+ x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
705
+ elif len(x_shape) == 4:
706
+ x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
707
+ else:
708
+ x = shifted_x
709
+
710
+ if len(x_shape) == 5:
711
+ if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
712
+ x = x[:, :d, :h, :w, :].contiguous()
713
+ elif len(x_shape) == 4:
714
+ if pad_r > 0 or pad_b > 0:
715
+ x = x[:, :h, :w, :].contiguous()
716
+
717
+ return x
718
+
719
+ def forward_part2(self, x):
720
+ return self.drop_path(self.mlp(self.norm2(x)))
721
+
722
+ def load_from(self, weights, n_block, layer):
723
+ root = f"module.{layer}.0.blocks.{n_block}."
724
+ block_names = [
725
+ "norm1.weight",
726
+ "norm1.bias",
727
+ "attn.relative_position_bias_table",
728
+ "attn.relative_position_index",
729
+ "attn.qkv.weight",
730
+ "attn.qkv.bias",
731
+ "attn.proj.weight",
732
+ "attn.proj.bias",
733
+ "norm2.weight",
734
+ "norm2.bias",
735
+ "mlp.fc1.weight",
736
+ "mlp.fc1.bias",
737
+ "mlp.fc2.weight",
738
+ "mlp.fc2.bias",
739
+ ]
740
+ with torch.no_grad():
741
+ self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
742
+ self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
743
+ self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
744
+ self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
745
+ self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
746
+ self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
747
+ self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
748
+ self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
749
+ self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
750
+ self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
751
+ self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
752
+ self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
753
+ self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
754
+ self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])
755
+
756
+ def forward(self, x, mask_matrix):
757
+ shortcut = x
758
+ if self.use_checkpoint:
759
+ x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
760
+ else:
761
+ x = self.forward_part1(x, mask_matrix)
762
+ x = shortcut + self.drop_path(x)
763
+ if self.use_checkpoint:
764
+ x = x + checkpoint.checkpoint(self.forward_part2, x)
765
+ else:
766
+ x = x + self.forward_part2(x)
767
+ return x
768
+
769
+
770
+ class PatchMerging(nn.Module):
771
+ """
772
+ Patch merging layer based on: "Liu et al.,
773
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
774
+ <https://arxiv.org/abs/2103.14030>"
775
+ https://github.com/microsoft/Swin-Transformer
776
+ """
777
+
778
+ def __init__(
779
+ self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
780
+ ) -> None: # type: ignore
781
+ """
782
+ Args:
783
+ dim: number of feature channels.
784
+ norm_layer: normalization layer.
785
+ spatial_dims: number of spatial dims.
786
+ """
787
+
788
+ super().__init__()
789
+ self.dim = dim
790
+ if spatial_dims == 3:
791
+ self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
792
+ self.norm = norm_layer(8 * dim)
793
+ elif spatial_dims == 2:
794
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
795
+ self.norm = norm_layer(4 * dim)
796
+
797
+ def forward(self, x):
798
+
799
+ x_shape = x.size()
800
+ if len(x_shape) == 5:
801
+ b, d, h, w, c = x_shape
802
+ pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
803
+ if pad_input:
804
+ x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))  # F.pad fills trailing dims first: (c, w, h, d)
805
+ x0 = x[:, 0::2, 0::2, 0::2, :]
806
+ x1 = x[:, 1::2, 0::2, 0::2, :]
807
+ x2 = x[:, 0::2, 1::2, 0::2, :]
808
+ x3 = x[:, 0::2, 0::2, 1::2, :]
809
+ x4 = x[:, 1::2, 0::2, 1::2, :]
810
+ x5 = x[:, 0::2, 1::2, 0::2, :]
811
+ x6 = x[:, 0::2, 0::2, 1::2, :]
812
+ x7 = x[:, 1::2, 1::2, 1::2, :]
813
+ x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
814
+
815
+ elif len(x_shape) == 4:
816
+ b, h, w, c = x_shape
817
+ pad_input = (h % 2 == 1) or (w % 2 == 1)
818
+ if pad_input:
819
+ x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
820
+ x0 = x[:, 0::2, 0::2, :]
821
+ x1 = x[:, 1::2, 0::2, :]
822
+ x2 = x[:, 0::2, 1::2, :]
823
+ x3 = x[:, 1::2, 1::2, :]
824
+ x = torch.cat([x0, x1, x2, x3], -1)
825
+
826
+ x = self.norm(x)
827
+ x = self.reduction(x)
828
+ return x
829
+
830
+
831
+ def compute_mask(dims, window_size, shift_size, device):
832
+ """Computing region masks based on: "Liu et al.,
833
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
834
+ <https://arxiv.org/abs/2103.14030>"
835
+ https://github.com/microsoft/Swin-Transformer
836
+ Args:
837
+ dims: dimension values.
838
+ window_size: local window size.
839
+ shift_size: shift size.
840
+ device: device.
841
+ """
842
+
843
+ cnt = 0
844
+
845
+ if len(dims) == 3:
846
+ d, h, w = dims
847
+ img_mask = torch.zeros((1, d, h, w, 1), device=device)
848
+ for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
849
+ for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
850
+ for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
851
+ img_mask[:, d, h, w, :] = cnt
852
+ cnt += 1
853
+
854
+ elif len(dims) == 2:
855
+ h, w = dims
856
+ img_mask = torch.zeros((1, h, w, 1), device=device)
857
+ for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
858
+ for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
859
+ img_mask[:, h, w, :] = cnt
860
+ cnt += 1
861
+
862
+ mask_windows = window_partition(img_mask, window_size)
863
+ mask_windows = mask_windows.squeeze(-1)
864
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
865
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
866
+
867
+ return attn_mask
868
+
869
+
870
+ class BasicLayer(nn.Module):
871
+ """
872
+ Basic Swin Transformer layer in one stage based on: "Liu et al.,
873
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
874
+ <https://arxiv.org/abs/2103.14030>"
875
+ https://github.com/microsoft/Swin-Transformer
876
+ """
877
+
878
+ def __init__(
879
+ self,
880
+ dim: int,
881
+ depth: int,
882
+ num_heads: int,
883
+ window_size: Sequence[int],
884
+ drop_path: list,
885
+ mlp_ratio: float = 4.0,
886
+ qkv_bias: bool = False,
887
+ drop: float = 0.0,
888
+ attn_drop: float = 0.0,
889
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
890
+ downsample: isinstance = None, # type: ignore
891
+ use_checkpoint: bool = False,
892
+ ) -> None:
893
+ """
894
+ Args:
895
+ dim: number of feature channels.
896
+ depth: number of layers in this stage.
897
+ num_heads: number of attention heads.
898
+ window_size: local window size.
899
+ drop_path: stochastic depth rate.
900
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
901
+ qkv_bias: add a learnable bias to query, key, value.
902
+ drop: dropout rate.
903
+ attn_drop: attention dropout rate.
904
+ norm_layer: normalization layer.
905
+ downsample: downsample layer at the end of the layer.
906
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
907
+ """
908
+
909
+ super().__init__()
910
+ self.window_size = window_size
911
+ self.shift_size = tuple(i // 2 for i in window_size)
912
+ self.no_shift = tuple(0 for i in window_size)
913
+ self.depth = depth
914
+ self.use_checkpoint = use_checkpoint
915
+ self.blocks = nn.ModuleList(
916
+ [
917
+ SwinTransformerBlock(
918
+ dim=dim,
919
+ num_heads=num_heads,
920
+ window_size=self.window_size,
921
+ shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
922
+ mlp_ratio=mlp_ratio,
923
+ qkv_bias=qkv_bias,
924
+ drop=drop,
925
+ attn_drop=attn_drop,
926
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
927
+ norm_layer=norm_layer,
928
+ use_checkpoint=use_checkpoint,
929
+ )
930
+ for i in range(depth)
931
+ ]
932
+ )
933
+ self.downsample = downsample
934
+ if self.downsample is not None:
935
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
936
+
937
+ def forward(self, x):
938
+ x_shape = x.size()
939
+ if len(x_shape) == 5:
940
+ b, c, d, h, w = x_shape
941
+ window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
942
+ x = rearrange(x, "b c d h w -> b d h w c")
943
+ dp = int(np.ceil(d / window_size[0])) * window_size[0]
944
+ hp = int(np.ceil(h / window_size[1])) * window_size[1]
945
+ wp = int(np.ceil(w / window_size[2])) * window_size[2]
946
+ attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
947
+ for blk in self.blocks:
948
+ x = blk(x, attn_mask)
949
+ x = x.view(b, d, h, w, -1)
950
+ if self.downsample is not None:
951
+ x = self.downsample(x)
952
+ x = rearrange(x, "b d h w c -> b c d h w")
953
+
954
+ elif len(x_shape) == 4:
955
+ b, c, h, w = x_shape
956
+ window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
957
+ x = rearrange(x, "b c h w -> b h w c")
958
+ hp = int(np.ceil(h / window_size[0])) * window_size[0]
959
+ wp = int(np.ceil(w / window_size[1])) * window_size[1]
960
+ attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
961
+ for blk in self.blocks:
962
+ x = blk(x, attn_mask)
963
+ x = x.view(b, h, w, -1)
964
+ if self.downsample is not None:
965
+ x = self.downsample(x)
966
+ x = rearrange(x, "b h w c -> b c h w")
967
+ return x
968
+
969
+
970
+ class SwinTransformer(nn.Module):
971
+ """
972
+ Swin Transformer based on: "Liu et al.,
973
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
974
+ <https://arxiv.org/abs/2103.14030>"
975
+ https://github.com/microsoft/Swin-Transformer
976
+ """
977
+
978
+ def __init__(
979
+ self,
980
+ in_chans: int,
981
+ embed_dim: int,
982
+ window_size: Sequence[int],
983
+ patch_size: Sequence[int],
984
+ depths: Sequence[int],
985
+ num_heads: Sequence[int],
986
+ mlp_ratio: float = 4.0,
987
+ qkv_bias: bool = True,
988
+ drop_rate: float = 0.0,
989
+ attn_drop_rate: float = 0.0,
990
+ drop_path_rate: float = 0.0,
991
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
992
+ patch_norm: bool = False,
993
+ use_checkpoint: bool = False,
994
+ spatial_dims: int = 3,
995
+ ) -> None:
996
+ """
997
+ Args:
998
+ in_chans: dimension of input channels.
999
+ embed_dim: number of linear projection output channels.
1000
+ window_size: local window size.
1001
+ patch_size: patch size.
1002
+ depths: number of layers in each stage.
1003
+ num_heads: number of attention heads.
1004
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
1005
+ qkv_bias: add a learnable bias to query, key, value.
1006
+ drop_rate: dropout rate.
1007
+ attn_drop_rate: attention dropout rate.
1008
+ drop_path_rate: stochastic depth rate.
1009
+ norm_layer: normalization layer.
1010
+ patch_norm: add normalization after patch embedding.
1011
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
1012
+ spatial_dims: spatial dimension.
1013
+ """
1014
+
1015
+ super().__init__()
1016
+ self.num_layers = len(depths)
1017
+ self.embed_dim = embed_dim
1018
+ self.patch_norm = patch_norm
1019
+ self.window_size = window_size
1020
+ self.patch_size = patch_size
1021
+ self.patch_embed = PatchEmbed(
1022
+ patch_size=self.patch_size,
1023
+ in_chans=in_chans,
1024
+ embed_dim=embed_dim,
1025
+ norm_layer=norm_layer if self.patch_norm else None, # type: ignore
1026
+ spatial_dims=spatial_dims,
1027
+ )
1028
+ self.pos_drop = nn.Dropout(p=drop_rate)
1029
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1030
+ self.layers1 = nn.ModuleList()
1031
+ self.layers2 = nn.ModuleList()
1032
+ self.layers3 = nn.ModuleList()
1033
+ self.layers4 = nn.ModuleList()
1034
+ for i_layer in range(self.num_layers):
1035
+ layer = BasicLayer(
1036
+ dim=int(embed_dim * 2**i_layer),
1037
+ depth=depths[i_layer],
1038
+ num_heads=num_heads[i_layer],
1039
+ window_size=self.window_size,
1040
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
1041
+ mlp_ratio=mlp_ratio,
1042
+ qkv_bias=qkv_bias,
1043
+ drop=drop_rate,
1044
+ attn_drop=attn_drop_rate,
1045
+ norm_layer=norm_layer,
1046
+ downsample=PatchMerging,
1047
+ use_checkpoint=use_checkpoint,
1048
+ )
1049
+ if i_layer == 0:
1050
+ self.layers1.append(layer)
1051
+ elif i_layer == 1:
1052
+ self.layers2.append(layer)
1053
+ elif i_layer == 2:
1054
+ self.layers3.append(layer)
1055
+ elif i_layer == 3:
1056
+ self.layers4.append(layer)
1057
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
1058
+
1059
+ def proj_out(self, x, normalize=False):
1060
+ if normalize:
1061
+ x_shape = x.size()
1062
+ if len(x_shape) == 5:
1063
+ n, ch, d, h, w = x_shape
1064
+ x = rearrange(x, "n c d h w -> n d h w c")
1065
+ x = F.layer_norm(x, [ch])
1066
+ x = rearrange(x, "n d h w c -> n c d h w")
1067
+ elif len(x_shape) == 4:
1068
+ n, ch, h, w = x_shape
1069
+ x = rearrange(x, "n c h w -> n h w c")
1070
+ x = F.layer_norm(x, [ch])
1071
+ x = rearrange(x, "n h w c -> n c h w")
1072
+ return x
1073
+
1074
+ def forward(self, x, normalize=True):
1075
+ x0 = self.patch_embed(x)
1076
+ x0 = self.pos_drop(x0)
1077
+ x0_out = self.proj_out(x0, normalize)
1078
+ x1 = self.layers1[0](x0.contiguous())
1079
+ x1_out = self.proj_out(x1, normalize)
1080
+ x2 = self.layers2[0](x1.contiguous())
1081
+ x2_out = self.proj_out(x2, normalize)
1082
+ x3 = self.layers3[0](x2.contiguous())
1083
+ x3_out = self.proj_out(x3, normalize)
1084
+ x4 = self.layers4[0](x3.contiguous())
1085
+ x4_out = self.proj_out(x4, normalize)
1086
+ return [x0_out, x1_out, x2_out, x3_out, x4_out]
1087
+
1088
+ if __name__ == '__main__':
1089
+ import os
1090
+ def get_parameter_number(model):
1091
+ total_num = sum(p.numel() for p in model.parameters())
1092
+ trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
1093
+ return {'Total': total_num, 'Trainable': trainable_num}
1094
+
1095
+ model = SwinUNETR(
1096
+ img_size=[288, 288, 96], # each of d, h, w of the real input must be larger than 32
1097
+ in_channels=3,
1098
+ feature_size=48,
1099
+ drop_rate=0.0,
1100
+ attn_drop_rate=0.0,
1101
+ dropout_path_rate=0.0,
1102
+ use_checkpoint=False,
1103
+ deep_supervision=True,
1104
+ return_skips=True,
1105
+ ).cuda()
1106
+
1107
+ if is_master():
1108
+ print(f"** UNET ** {get_parameter_number(model)['Total']/1e6}M parameters")
1109
+
1110
+ image = torch.rand((1, 3, 288, 288, 96)).cuda()
1111
+ skips, outs = model(image)
1112
+
1113
+ for s in skips:
1114
+ print(s.shape)
1115
+ for out in outs:
1116
+ print(out.shape)
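PatchMerging (defined above) halves every spatial axis and doubles the channel width: for 3D inputs, 8 neighbouring positions are concatenated into an 8*dim vector and reduced to 2*dim by the linear layer. A minimal shape check, assuming the class is importable as model.SwinUNETR.PatchMerging and the module's dependencies are installed:

    import torch
    from model.SwinUNETR import PatchMerging   # import path assumed from this repository layout

    merge = PatchMerging(dim=48, spatial_dims=3)
    x = torch.rand(1, 8, 8, 4, 48)             # channel-last (b, d, h, w, c), as expected by forward()
    y = merge(x)
    print(y.shape)                             # torch.Size([1, 4, 4, 2, 96]): spatial dims halved, channels doubled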
model/__init__.py ADDED
File without changes
model/base_bert.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ from transformers import BertModel, AutoTokenizer
5
+
6
+ class BaseBERT(nn.Module):
7
+ def __init__(self, basebert_checkpoint='bert-base-uncased'):
8
+ super().__init__()
9
+ self.tokenizer = AutoTokenizer.from_pretrained(basebert_checkpoint)
10
+ self.model = BertModel.from_pretrained(basebert_checkpoint)
11
+ self.modality_embed = nn.Embedding(4, 768)
12
+
13
+ def forward(self, text, modality):
14
+ encoded = self.tokenizer(
15
+ text,
16
+ truncation=True,
17
+ padding=True,
18
+ return_tensors='pt',
19
+ max_length=64,
20
+ ).to(device=torch.cuda.current_device())
21
+
22
+ text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
23
+ modality_feature = self.modality_embed(modality)
24
+ text_feature += modality_feature
25
+
26
+ return text_feature
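BaseBERT tokenises a batch of text prompts, takes the [CLS] feature from BERT and adds a learned modality embedding to it. A hedged usage sketch (a CUDA device is required, since forward() moves the tokens to the current GPU; the import path is assumed from this repository layout):

    import torch
    from model.base_bert import BaseBERT

    text_encoder = BaseBERT('bert-base-uncased').cuda()
    prompts = ['liver', 'left kidney']
    modality = torch.tensor([0, 0]).cuda()     # indices into the 4-entry modality embedding table
    with torch.no_grad():
        feats = text_encoder(prompts, modality)
    print(feats.shape)                         # (2, 768): [CLS] feature plus modality embedding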
model/build_model.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import time
4
+ import os
5
+ from torch.nn.parallel import DistributedDataParallel as DDP
6
+
7
+ import numpy as np
8
+
9
+ from .maskformer import Maskformer
10
+
11
+ from train.dist import is_master
12
+
13
+
14
+ def get_parameter_number(model):
15
+ total_num = sum(p.numel() for p in model.parameters())
16
+ trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
17
+ return {'Total': total_num, 'Trainable': trainable_num}
18
+
19
+
20
+ def build_maskformer(args, device, gpu_id):
21
+ model = Maskformer(args.vision_backbone, args.input_channels, args.crop_size, args.patch_size, args.deep_supervision)
22
+
23
+ model = model.to(device)
24
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
25
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_id], find_unused_parameters=True)
26
+
27
+ def get_parameter_number(model):
28
+ total_num = sum(p.numel() for p in model.parameters())
29
+ trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
30
+ return {'Total': total_num, 'Trainable': trainable_num}
31
+
32
+ if is_master():
33
+ print(f"** MODEL ** {get_parameter_number(model)['Total']/1e6}M parameters")
34
+
35
+ return model
36
+
37
+
38
+ def load_checkpoint(checkpoint_file,
39
+ resume,
40
+ partial_load,
41
+ model,
42
+ device,
43
+ optimizer=None,
44
+ ):
45
+
46
+ if is_master():
47
+ print('** CHECKPOINT ** : Load checkpoint from %s' % (checkpoint_file))
48
+
49
+ checkpoint = torch.load(checkpoint_file, map_location=device)
50
+
51
+ # load part of the checkpoint
52
+ if partial_load:
53
+ model_dict = model.state_dict()
54
+ # check difference
55
+ unexpected_state_dict = [k for k in checkpoint['model_state_dict'].keys() if k not in model_dict.keys()]
56
+ missing_state_dict = [k for k in model_dict.keys() if k not in checkpoint['model_state_dict'].keys()]
57
+ unmatched_state_dict = [k for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape != model_dict[k].shape]
58
+ # load partial parameters
59
+ state_dict = {k:v for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape == model_dict[k].shape}
60
+ model_dict.update(state_dict)
61
+ model.load_state_dict(model_dict)
62
+ if is_master():
63
+ print('The following parameters are unexpected in SAT checkpoint:\n', unexpected_state_dict)
64
+ print('The following parameters are missing in SAT checkpoint:\n', missing_state_dict)
65
+ print('The following parameters have different shapes in SAT checkpoint:\n', unmatched_state_dict)
66
+ print('The following parameters are loaded in SAT:\n', state_dict.keys())
67
+ else:
68
+ model.load_state_dict(checkpoint['model_state_dict'])
69
+
70
+ # if resume, load optimizer and step
71
+ if resume:
72
+ try:
73
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
74
+ except Exception:
75
+ print('Optimizer state dict does not match; skipping optimizer state loading')
76
+ pass
77
+ start_step = int(checkpoint['step']) + 1
78
+ print('Resume from step %d' % (start_step))
79
+ else:
80
+ start_step = 1
81
+
82
+ return model, optimizer, start_step
83
+
84
+
85
+ def inherit_knowledge_encoder(knowledge_encoder_checkpoint,
86
+ model,
87
+ device
88
+ ):
89
+ # inherit unet encoder and multiscale feature projection layer from knowledge encoder
90
+ checkpoint = torch.load(knowledge_encoder_checkpoint, map_location=device)
91
+
92
+ model_dict = model.state_dict()
93
+ visual_encoder_state_dict = {k.replace('atlas_tower', 'backbone'):v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower.encoder' in k} # encoder part
94
+ model_dict.update(visual_encoder_state_dict)
95
+ proj_state_dict = {k.replace('atlas_tower.', ''):v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower.projection_layer' in k} # projection layer part
96
+ model_dict.update(proj_state_dict)
97
+ model.load_state_dict(model_dict)
98
+
99
+ if is_master():
100
+ print('** CHECKPOINT ** : Inherit pretrained unet encoder from %s' % (knowledge_encoder_checkpoint))
101
+ print('The following parameters are loaded in SAT:\n', list(visual_encoder_state_dict.keys())+list(proj_state_dict.keys()))
102
+
103
+ return model
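A sketch of how load_checkpoint is typically called. This assumes the repository's distributed helpers behind is_master() are importable and initialised, and that the checkpoint file holds the 'model_state_dict', 'optimizer_state_dict' and 'step' keys used above; the path and the stand-in model are placeholders.

    import torch
    import torch.nn as nn
    from model.build_model import load_checkpoint

    device = torch.device('cuda', 0)
    model = nn.Conv3d(1, 8, kernel_size=3).to(device)   # stand-in; the real model comes from build_maskformer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer, start_step = load_checkpoint(
        checkpoint_file='/path/to/checkpoint.pth',       # placeholder path
        resume=True,         # also restore the optimizer state and the step counter
        partial_load=True,   # keep only parameters whose names and shapes match
        model=model,
        device=device,
        optimizer=optimizer,
    )
    print('resuming from step', start_step)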
model/dynamic-network-architectures-main/.gitignore ADDED
@@ -0,0 +1,113 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ env/
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+
27
+ # PyInstaller
28
+ # Usually these files are written by a python script from a template
29
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
30
+ *.manifest
31
+ *.spec
32
+
33
+ # Installer logs
34
+ pip-log.txt
35
+ pip-delete-this-directory.txt
36
+
37
+ # Unit test / coverage reports
38
+ htmlcov/
39
+ .tox/
40
+ .coverage
41
+ .coverage.*
42
+ .cache
43
+ nosetests.xml
44
+ coverage.xml
45
+ *,cover
46
+ .hypothesis/
47
+
48
+ # Translations
49
+ *.mo
50
+ *.pot
51
+
52
+ # Django stuff:
53
+ *.log
54
+ local_settings.py
55
+
56
+ # Flask stuff:
57
+ instance/
58
+ .webassets-cache
59
+
60
+ # Scrapy stuff:
61
+ .scrapy
62
+
63
+ # Sphinx documentation
64
+ docs/_build/
65
+
66
+ # PyBuilder
67
+ target/
68
+
69
+ # IPython Notebook
70
+ .ipynb_checkpoints
71
+
72
+ # pyenv
73
+ .python-version
74
+
75
+ # celery beat schedule file
76
+ celerybeat-schedule
77
+
78
+ # dotenv
79
+ .env
80
+
81
+ # virtualenv
82
+ venv/
83
+ ENV/
84
+
85
+ # Spyder project settings
86
+ .spyderproject
87
+
88
+ # Rope project settings
89
+ .ropeproject
90
+
91
+ *.memmap
92
+ *.zip
93
+ *.npz
94
+ *.npy
95
+ *.jpg
96
+ *.jpeg
97
+ .idea
98
+ *.txt
99
+ .idea/*
100
+ *.nii.gz
101
+ *.nii
102
+ *.tif
103
+ *.bmp
104
+ *.pkl
105
+ *.xml
106
+ *.pkl
107
+ *.pdf
108
+ *.jpg
109
+ *.jpeg
110
+
111
+ *.model
112
+
113
+ cifar_lightning/mlruns*
model/dynamic-network-architectures-main/LICENCE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2022] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
model/dynamic-network-architectures-main/README.md ADDED
@@ -0,0 +1,25 @@
1
+ # Dynamic Network Architectures
2
+
3
+ This repository contains several ResNet, U-Net and VGG architectures in PyTorch that can be dynamically adapted to different image dimensionalities (1D, 2D or 3D) and numbers of input channels.
4
+
5
+ ## Available models
6
+ ### ResNet
7
+ We implement the standard [ResNetD](https://arxiv.org/pdf/1812.01187.pdf) 18, 34, 50 and 152. Bottleneck implementations are also available for ResNets 50 and 152. In addition, adapted versions better suited to smaller image sizes such as CIFAR are provided.
8
+
9
+ All models additionally include regularization techniques like [Stochastic Depth](https://arxiv.org/pdf/1603.09382.pdf), [Squeeze & Excitation](https://arxiv.org/pdf/1709.01507.pdf) and [Final Layer Dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf).
10
+
11
+ ### VGG
12
+ In contrast to the original [VGG](https://arxiv.org/pdf/1409.1556.pdf) implementation, we drop the final stack of fully-connected layers and replace it with additional convolutional layers followed by a single fully-connected layer. Adapted versions better suited to smaller image sizes such as CIFAR are also provided.
13
+
14
+ ### U-Net
15
+ For the [U-Net](https://arxiv.org/pdf/1505.04597.pdf), both a plain convolutional encoder and a residual encoder are available.
16
+
17
+ # Acknowledgements
18
+
19
+ <p align="left">
20
+ <img src="imgs/Logos/HI_Logo.png" width="150"> &nbsp;&nbsp;&nbsp;&nbsp;
21
+ <img src="imgs/Logos/DKFZ_Logo.png" width="500">
22
+ </p>
23
+
24
+ This Repository is developed and maintained by the Applied Computer Vision Lab (ACVL)
25
+ of [Helmholtz Imaging](https://www.helmholtz-imaging.de/).
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/PKG-INFO ADDED
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.4
2
+ Name: dynamic_network_architectures
3
+ Version: 0.2
4
+ Summary: none
5
+ Author: Fabian Isensee
6
+ Author-email: f.isensee@dkfz.de
7
+ License: private
8
+ License-File: LICENCE
9
+ Requires-Dist: torch>=1.6.0a
10
+ Requires-Dist: numpy
11
+ Dynamic: author
12
+ Dynamic: author-email
13
+ Dynamic: license
14
+ Dynamic: license-file
15
+ Dynamic: requires-dist
16
+ Dynamic: summary
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,24 @@
1
+ LICENCE
2
+ README.md
3
+ setup.py
4
+ dynamic_network_architectures/__init__.py
5
+ dynamic_network_architectures.egg-info/PKG-INFO
6
+ dynamic_network_architectures.egg-info/SOURCES.txt
7
+ dynamic_network_architectures.egg-info/dependency_links.txt
8
+ dynamic_network_architectures.egg-info/not-zip-safe
9
+ dynamic_network_architectures.egg-info/requires.txt
10
+ dynamic_network_architectures.egg-info/top_level.txt
11
+ dynamic_network_architectures/architectures/__init__.py
12
+ dynamic_network_architectures/architectures/resnet.py
13
+ dynamic_network_architectures/architectures/unet.py
14
+ dynamic_network_architectures/architectures/vgg.py
15
+ dynamic_network_architectures/building_blocks/__init__.py
16
+ dynamic_network_architectures/building_blocks/helper.py
17
+ dynamic_network_architectures/building_blocks/plain_conv_encoder.py
18
+ dynamic_network_architectures/building_blocks/regularization.py
19
+ dynamic_network_architectures/building_blocks/residual.py
20
+ dynamic_network_architectures/building_blocks/residual_encoders.py
21
+ dynamic_network_architectures/building_blocks/simple_conv_blocks.py
22
+ dynamic_network_architectures/building_blocks/unet_decoder.py
23
+ dynamic_network_architectures/initialization/__init__.py
24
+ dynamic_network_architectures/initialization/weight_init.py
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
1
+
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/requires.txt ADDED
@@ -0,0 +1,2 @@
1
+ torch>=1.6.0a
2
+ numpy
model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ dynamic_network_architectures
model/dynamic-network-architectures-main/dynamic_network_architectures/__init__.py ADDED
File without changes
model/dynamic-network-architectures-main/dynamic_network_architectures/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (256 Bytes). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__init__.py ADDED
File without changes
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (270 Bytes). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/unet.cpython-310.pyc ADDED
Binary file (7.52 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/resnet.py ADDED
@@ -0,0 +1,236 @@
1
+ import torch
2
+ from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder, BottleneckD, BasicBlockD
3
+ from dynamic_network_architectures.building_blocks.helper import get_matching_pool_op, get_default_network_config
4
+ from dynamic_network_architectures.building_blocks.simple_conv_blocks import ConvDropoutNormReLU
5
+ from torch import nn
6
+
7
+ _ResNet_CONFIGS = {
8
+ '18': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (2, 2, 2, 2), 'strides': (1, 2, 2, 2),
9
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': True, 'stem_channels': None},
10
+ '34': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (3, 4, 6, 3), 'strides': (1, 2, 2, 2),
11
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': True, 'stem_channels': None},
12
+ '50': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 6, 10, 5), 'strides': (1, 2, 2, 2),
13
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': True, 'stem_channels': None},
14
+ '152': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 13, 55, 4), 'strides': (1, 2, 2, 2),
15
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': True, 'stem_channels': None},
16
+ '50_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 4, 6, 3), 'strides': (1, 2, 2, 2),
17
+ 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': True,
18
+ 'stem_channels': 64},
19
+ '152_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 8, 36, 3),
20
+ 'strides': (1, 2, 2, 2),
21
+ 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': True,
22
+ 'stem_channels': 64},
23
+ '18_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (2, 2, 2, 2), 'strides': (1, 2, 2, 2),
24
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False,
25
+ 'stem_channels': None},
26
+ '34_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (3, 4, 6, 3), 'strides': (1, 2, 2, 2),
27
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False,
28
+ 'stem_channels': None},
29
+ '50_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 6, 10, 5),
30
+ 'strides': (1, 2, 2, 2),
31
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False,
32
+ 'stem_channels': None},
33
+ '152_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 13, 55, 4),
34
+ 'strides': (1, 2, 2, 2),
35
+ 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False,
36
+ 'stem_channels': None},
37
+ '50_cifar_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 4, 6, 3),
38
+ 'strides': (1, 2, 2, 2),
39
+ 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': False,
40
+ 'stem_channels': 64},
41
+ '152_cifar_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 8, 36, 3),
42
+ 'strides': (1, 2, 2, 2),
43
+ 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': False,
44
+ 'stem_channels': 64},
45
+ }
46
+
47
+
48
+ class ResNetD(nn.Module):
49
+ def __init__(self, n_classes: int, n_input_channel: int = 3, config='18', input_dimension=2,
50
+ final_layer_dropout=0.0, stochastic_depth_p=0.0, squeeze_excitation=False,
51
+ squeeze_excitation_rd_ratio=1./16):
52
+ """
53
+ Implements ResNetD (https://arxiv.org/pdf/1812.01187.pdf).
54
+ Args:
55
+ n_classes: Number of classes
56
+ n_input_channel: Number of input channels (e.g. 3 for RGB)
57
+ config: Configuration of the ResNet
58
+ input_dimension: Number of dimensions of the data (1, 2 or 3)
59
+ final_layer_dropout: Probability of dropout before the final classifier
60
+ stochastic_depth_p: Stochastic Depth probability
61
+ squeeze_excitation: Whether Squeeze and Excitation should be applied
62
+ squeeze_excitation_rd_ratio: Squeeze and Excitation Reduction Ratio
63
+ Returns:
64
+ ResNet Model
65
+ """
66
+ super().__init__()
67
+ self.input_channels = n_input_channel
68
+ self.cfg = _ResNet_CONFIGS[config]
69
+ self.ops = get_default_network_config(dimension=input_dimension)
70
+ self.final_layer_dropout_p = final_layer_dropout
71
+
72
+ if self.cfg['disable_default_stem']:
73
+ stem_features = self.cfg['stem_channels'] if self.cfg['stem_channels'] is not None else \
74
+ self.cfg['features_per_stage'][0]
75
+ self.stem = self._build_imagenet_stem_D(stem_features)
76
+ encoder_input_features = stem_features
77
+ else:
78
+ encoder_input_features = n_input_channel
79
+ self.stem = None
80
+
81
+ self.encoder = ResidualEncoder(encoder_input_features, n_stages=len(self.cfg['features_per_stage']),
82
+ features_per_stage=self.cfg['features_per_stage'], conv_op=self.ops['conv_op'],
83
+ kernel_sizes=3, strides=self.cfg['strides'],
84
+ n_blocks_per_stage=self.cfg['n_blocks_per_stage'], conv_bias=False,
85
+ norm_op=self.ops['norm_op'], norm_op_kwargs=None, dropout_op=None,
86
+ dropout_op_kwargs=None, nonlin=nn.ReLU,
87
+ nonlin_kwargs={'inplace': True}, block=self.cfg['block'],
88
+ bottleneck_channels=self.cfg['bottleneck_channels'], return_skips=False,
89
+ disable_default_stem=self.cfg['disable_default_stem'],
90
+ stem_channels=self.cfg['stem_channels'],
91
+ stochastic_depth_p=stochastic_depth_p,
92
+ squeeze_excitation=squeeze_excitation,
93
+ squeeze_excitation_reduction_ratio=squeeze_excitation_rd_ratio)
94
+
95
+ self.gap = get_matching_pool_op(conv_op=self.ops['conv_op'], adaptive=True, pool_type='avg')(1)
96
+ self.classifier = nn.Linear(self.cfg['features_per_stage'][-1], n_classes, True)
97
+ self.final_layer_dropout = self.ops['dropout_op'](p=self.final_layer_dropout_p)
98
+
99
+ def forward(self, x):
100
+ if self.stem is not None:
101
+ x = self.stem(x)
102
+ x = self.encoder(x)
103
+ x = self.gap(x)
104
+ x = self.final_layer_dropout(x).squeeze()
105
+
106
+ return self.classifier(x)
107
+
108
+ def _build_imagenet_stem_D(self, stem_features):
109
+ """
110
+ https://arxiv.org/pdf/1812.01187.pdf
111
+
112
+ use 3 3x3(x3) convs instead of one 7x7. Stride is located in first conv.
113
+
114
+ Fig2 b) describes this
115
+ :return:
116
+ """
117
+ c1 = ConvDropoutNormReLU(self.ops['conv_op'], self.input_channels, stem_features, 3, 2, False,
118
+ self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True})
119
+ c2 = ConvDropoutNormReLU(self.ops['conv_op'], stem_features, stem_features, 3, 1, False,
120
+ self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True})
121
+ c3 = ConvDropoutNormReLU(self.ops['conv_op'], stem_features, stem_features, 3, 1, False,
122
+ self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True})
123
+ pl = get_matching_pool_op(conv_op=self.ops['conv_op'], adaptive=False, pool_type='max')(2)
124
+ stem = nn.Sequential(c1, c2, c3, pl)
125
+ return stem
126
+
127
+
128
+ class ResNet18_CIFAR(ResNetD):
129
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
130
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
131
+ squeeze_excitation_rd_ratio: float = 1./16):
132
+ super().__init__(n_classes, n_input_channels, config='18_cifar', input_dimension=input_dimension,
133
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
134
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
135
+
136
+ class ResNet34_CIFAR(ResNetD):
137
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
138
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
139
+ squeeze_excitation_rd_ratio: float = 1./16):
140
+ super().__init__(n_classes, n_input_channels, config='34_cifar', input_dimension=input_dimension,
141
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
142
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
143
+
144
+ class ResNet50_CIFAR(ResNetD):
145
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
146
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
147
+ squeeze_excitation_rd_ratio: float = 1./16):
148
+ super().__init__(n_classes, n_input_channels, config='50_cifar', input_dimension=input_dimension,
149
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
150
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
151
+
152
+ class ResNet152_CIFAR(ResNetD):
153
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
154
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
155
+ squeeze_excitation_rd_ratio: float = 1./16):
156
+ super().__init__(n_classes, n_input_channels, config='152_cifar', input_dimension=input_dimension,
157
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
158
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
159
+
160
+ class ResNet50bn_CIFAR(ResNetD):
161
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
162
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
163
+ squeeze_excitation_rd_ratio: float = 1./16):
164
+ super().__init__(n_classes, n_input_channels, config='50_cifar_bn', input_dimension=input_dimension,
165
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
166
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
167
+
168
+ class ResNet152bn_CIFAR(ResNetD):
169
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
170
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
171
+ squeeze_excitation_rd_ratio: float = 1./16):
172
+ super().__init__(n_classes, n_input_channels, config='152_cifar_bn', input_dimension=input_dimension,
173
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
174
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
175
+
176
+ class ResNet18(ResNetD):
177
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
178
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
179
+ squeeze_excitation_rd_ratio: float = 1./16):
180
+ super().__init__(n_classes, n_input_channels, config='18', input_dimension=input_dimension,
181
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
182
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
183
+
184
+ class ResNet34(ResNetD):
185
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
186
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
187
+ squeeze_excitation_rd_ratio: float = 1./16):
188
+ super().__init__(n_classes, n_input_channels, config='34', input_dimension=input_dimension,
189
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
190
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
191
+
192
+ class ResNet50(ResNetD):
193
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
194
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
195
+ squeeze_excitation_rd_ratio: float = 1./16):
196
+ super().__init__(n_classes, n_input_channels, config='50', input_dimension=input_dimension,
197
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
198
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
199
+
200
+ class ResNet152(ResNetD):
201
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
202
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
203
+ squeeze_excitation_rd_ratio: float = 1./16):
204
+ super().__init__(n_classes, n_input_channels, config='152', input_dimension=input_dimension,
205
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
206
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
207
+
208
+ class ResNet50bn(ResNetD):
209
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
210
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
211
+ squeeze_excitation_rd_ratio: float = 1./16):
212
+ super().__init__(n_classes, n_input_channels, config='50_bn', input_dimension=input_dimension,
213
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
214
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
215
+
216
+ class ResNet152bn(ResNetD):
217
+ def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2,
218
+ final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False,
219
+ squeeze_excitation_rd_ratio: float = 1./16):
220
+ super().__init__(n_classes, n_input_channels, config='152_bn', input_dimension=input_dimension,
221
+ final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p,
222
+ squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio)
223
+
224
+
225
+ if __name__ == '__main__':
226
+ data = torch.rand((1, 3, 224, 224))
227
+
228
+ model = ResNet50bn(10, 3)
229
+ import hiddenlayer as hl
230
+
231
+ g = hl.build_graph(model, data,
232
+ transforms=None)
233
+ g.save("network_architecture.pdf")
234
+ del g
235
+
236
+ #print(model.compute_conv_feature_map_size((32, 32)))
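The classes above adapt to other dimensionalities through input_dimension. A minimal sketch for a 3D, single-channel input, assuming the bundled dynamic_network_architectures package (including the building_blocks helpers not shown here) is installed:

    import torch
    from dynamic_network_architectures.architectures.resnet import ResNet18

    model = ResNet18(n_classes=2, n_input_channels=1, input_dimension=3)
    x = torch.rand(2, 1, 64, 64, 64)   # (batch, channels, d, h, w)
    logits = model(x)
    print(logits.shape)                # (2, 2)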
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/unet.py ADDED
@@ -0,0 +1,220 @@
1
+ from typing import Union, Type, List, Tuple
2
+
3
+ import torch
4
+ from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
5
+ from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD
6
+ from torch import nn
7
+ from torch.nn.modules.conv import _ConvNd
8
+ from torch.nn.modules.dropout import _DropoutNd
9
+
10
+ from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
11
+ from dynamic_network_architectures.building_blocks.unet_decoder import UNetDecoder, UNetDecoder_Seg
12
+ from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
13
+
14
+
15
+ class PlainConvUNet(nn.Module):
16
+ def __init__(self,
17
+ input_channels: int,
18
+ n_stages: int,
19
+ features_per_stage: Union[int, List[int], Tuple[int, ...]],
20
+ conv_op: Type[_ConvNd],
21
+ kernel_sizes: Union[int, List[int], Tuple[int, ...]],
22
+ strides: Union[int, List[int], Tuple[int, ...]],
23
+ n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
24
+ n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
25
+ conv_bias: bool = False,
26
+ norm_op: Union[None, Type[nn.Module]] = None,
27
+ norm_op_kwargs: dict = None,
28
+ dropout_op: Union[None, Type[_DropoutNd]] = None,
29
+ dropout_op_kwargs: dict = None,
30
+ nonlin: Union[None, Type[torch.nn.Module]] = None, # activation
31
+ nonlin_kwargs: dict = None,
32
+ deep_supervision: bool = False,
33
+ nonlin_first: bool = False
34
+ ):
35
+ """
36
+ nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
37
+ """
38
+ super().__init__()
39
+ if isinstance(n_conv_per_stage, int):
40
+ n_conv_per_stage = [n_conv_per_stage] * n_stages
41
+ if isinstance(n_conv_per_stage_decoder, int):
42
+ n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
43
+ assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
44
+ f"resolution stages. here: {n_stages}. " \
45
+ f"n_conv_per_stage: {n_conv_per_stage}"
46
+ assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one entry fewer " \
47
+ f"than the number of resolution stages. here: {n_stages} " \
48
+ f"stages, so it should have {n_stages - 1} entries. " \
49
+ f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
50
+ self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
51
+ n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
52
+ dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True,
53
+ nonlin_first=nonlin_first)
54
+
55
+ self.decoder = UNetDecoder(self.encoder, n_conv_per_stage_decoder, deep_supervision,
56
+ nonlin_first=nonlin_first)
57
+
58
+ def forward(self, x):
59
+ skips = self.encoder(x) # [2, 32, 256, 256, 96] ... [2, 768, 8, 8, 3]
60
+ outs = self.decoder(skips) # [2, 32, 256, 256, 96] ... [2, 512, 16, 16, 6]
61
+ return skips, outs # latent_embeddings(a list of multiscale features), perpixel_embeddings(a list of decoder outputs)
62
+
63
+ def compute_conv_feature_map_size(self, input_size):
64
+ assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
65
+ "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
66
+ "Give input_size=(x, y(, z))!"
67
+ return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)
68
+
69
+
70
+ class PlainConvUNet_Seg(nn.Module):
71
+ def __init__(self,
72
+ input_channels: int,
73
+ n_stages: int,
74
+ features_per_stage: Union[int, List[int], Tuple[int, ...]],
75
+ conv_op: Type[_ConvNd],
76
+ kernel_sizes: Union[int, List[int], Tuple[int, ...]],
77
+ strides: Union[int, List[int], Tuple[int, ...]],
78
+ n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
79
+ num_classes: int,
80
+ n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
81
+ conv_bias: bool = False,
82
+ norm_op: Union[None, Type[nn.Module]] = None,
83
+ norm_op_kwargs: dict = None,
84
+ dropout_op: Union[None, Type[_DropoutNd]] = None,
85
+ dropout_op_kwargs: dict = None,
86
+ nonlin: Union[None, Type[torch.nn.Module]] = None, # activation
87
+ nonlin_kwargs: dict = None,
88
+ deep_supervision: bool = False,
89
+ nonlin_first: bool = False
90
+ ):
91
+ """
92
+ nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
93
+ """
94
+ super().__init__()
95
+ if isinstance(n_conv_per_stage, int):
96
+ n_conv_per_stage = [n_conv_per_stage] * n_stages
97
+ if isinstance(n_conv_per_stage_decoder, int):
98
+ n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
99
+ assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
100
+ f"resolution stages. here: {n_stages}. " \
101
+ f"n_conv_per_stage: {n_conv_per_stage}"
102
+ assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one entry fewer " \
103
+ f"than the number of resolution stages. here: {n_stages} " \
104
+ f"stages, so it should have {n_stages - 1} entries. " \
105
+ f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
106
+ self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
107
+ n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
108
+ dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True,
109
+ nonlin_first=nonlin_first)
110
+ self.decoder = UNetDecoder_Seg(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
111
+ nonlin_first=nonlin_first)
112
+
113
+ def forward(self, x):
114
+ skips = self.encoder(x) # [2, 32, 256, 256, 96] ... [2, 768, 8, 8, 3]
115
+ out = self.decoder(skips) # [2, num_class, 256, 256, 96]
116
+ return out
117
+
118
+ def compute_conv_feature_map_size(self, input_size):
119
+ assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
120
+ "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
121
+ "Give input_size=(x, y(, z))!"
122
+ return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)
123
+
124
+
125
+ class ResidualEncoderUNet(nn.Module):
126
+ def __init__(self,
127
+ input_channels: int,
128
+ n_stages: int,
129
+ features_per_stage: Union[int, List[int], Tuple[int, ...]],
130
+ conv_op: Type[_ConvNd],
131
+ kernel_sizes: Union[int, List[int], Tuple[int, ...]],
132
+ strides: Union[int, List[int], Tuple[int, ...]],
133
+ n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
134
+ n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
135
+ conv_bias: bool = False,
136
+ norm_op: Union[None, Type[nn.Module]] = None,
137
+ norm_op_kwargs: dict = None,
138
+ dropout_op: Union[None, Type[_DropoutNd]] = None,
139
+ dropout_op_kwargs: dict = None,
140
+ nonlin: Union[None, Type[torch.nn.Module]] = None,
141
+ nonlin_kwargs: dict = None,
142
+ deep_supervision: bool = False,
143
+ block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
144
+ bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
145
+ stem_channels: int = None
146
+ ):
147
+ super().__init__()
148
+ if isinstance(n_blocks_per_stage, int):
149
+ n_blocks_per_stage = [n_blocks_per_stage] * n_stages
150
+ if isinstance(n_conv_per_stage_decoder, int):
151
+ n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
152
+ assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \
153
+ f"resolution stages. here: {n_stages}. " \
154
+ f"n_blocks_per_stage: {n_blocks_per_stage}"
155
+ assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one entry fewer " \
156
+ f"as we have resolution stages. here: {n_stages} " \
157
+ f"stages, so it should have {n_stages - 1} entries. " \
158
+ f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
159
+ self.encoder = ResidualEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
160
+ n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
161
+ dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels,
162
+ return_skips=True, disable_default_stem=False, stem_channels=stem_channels)
163
+
164
+ self.decoder = UNetDecoder(self.encoder, n_conv_per_stage_decoder, deep_supervision)
165
+
166
+ def forward(self, x):
167
+ skips = self.encoder(x) # [2, 32, 256, 256, 96] ... [2, 768, 8, 8, 3]
168
+ outs = self.decoder(skips) # [2, 32, 256, 256, 96] ... [2, 512, 16, 16, 6]
169
+ return skips, outs # latent_embeddings(a list of multiscale features), perpixel_embeddings(a list of decoder outputs)
170
+
171
+ def compute_conv_feature_map_size(self, input_size):
172
+ assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
173
+ "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
174
+ "Give input_size=(x, y(, z))!"
175
+ return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)
176
+
177
+
178
+ if __name__ == '__main__':
179
+ import sys
180
+ sys.path.append('/remote-home/zihengzhao/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/model/dynamic-network-architectures-main')
181
+
182
+ data = torch.rand((2, 3, 256, 256, 96)).cuda()
183
+
184
+ model = PlainConvUNet(3, 6, (32, 64, 128, 256, 512, 768), nn.Conv3d, 3, (1, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2), 4,
185
+ (2, 2, 2, 2, 2), False, nn.BatchNorm3d, None, None, None, nn.ReLU, deep_supervision=True).cuda()
186
+
187
+ dec_outs, enc_outs = model(data)
188
+ print('DEC')
189
+ for i in dec_outs:
190
+ print(i.shape) # (2, 4, 256, 256, 96)
191
+ print('ENC')
192
+ for i in enc_outs:
193
+ print(i.shape) # ()
194
+ exit()
195
+
196
+
197
+ if False:
198
+ import hiddenlayer as hl
199
+
200
+ g = hl.build_graph(model, data,
201
+ transforms=None)
202
+ g.save("network_architecture.pdf")
203
+ del g
204
+
205
+ print(model.compute_conv_feature_map_size(data.shape[2:]))
206
+
207
+ data = torch.rand((1, 4, 512, 512))
208
+
209
+ model = PlainConvUNet(4, 8, (32, 64, 125, 256, 512, 512, 512, 512), nn.Conv2d, 3, (1, 2, 2, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2, 2, 2), 4,
210
+ (2, 2, 2, 2, 2, 2, 2), False, nn.BatchNorm2d, None, None, None, nn.ReLU, deep_supervision=True)
211
+
212
+ if False:
213
+ import hiddenlayer as hl
214
+
215
+ g = hl.build_graph(model, data,
216
+ transforms=None)
217
+ g.save("network_architecture.pdf")
218
+ del g
219
+
220
+ print(model.compute_conv_feature_map_size(data.shape[2:]))
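A minimal usage sketch for the segmentation variant added above (not part of the commit). It assumes the dynamic-network-architectures-main directory is on sys.path so that dynamic_network_architectures.architectures.unet exposes PlainConvUNet_Seg; all hyperparameter values below are illustrative, not the repository's training configuration.

import torch
from torch import nn
from dynamic_network_architectures.architectures.unet import PlainConvUNet_Seg

# 4-stage 3D U-Net producing per-voxel class logits
model = PlainConvUNet_Seg(
    input_channels=1,
    n_stages=4,
    features_per_stage=(32, 64, 128, 256),
    conv_op=nn.Conv3d,
    kernel_sizes=3,
    strides=(1, 2, 2, 2),
    n_conv_per_stage=2,
    num_classes=3,
    n_conv_per_stage_decoder=2,
    norm_op=nn.InstanceNorm3d,
    norm_op_kwargs={'affine': True},
    nonlin=nn.ReLU,
    nonlin_kwargs={'inplace': True},
)
model.eval()

x = torch.rand(1, 1, 64, 64, 32)                  # (batch, channel, x, y, z)
with torch.no_grad():
    logits = model(x)                             # (1, num_classes, 64, 64, 32), per the forward() comment above
print(logits.shape)
# spatial size only -- no batch/channel dims, as the assert in compute_conv_feature_map_size demands
print(model.compute_conv_feature_map_size((64, 64, 32)))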
model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/vgg.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
5
+ from dynamic_network_architectures.building_blocks.helper import get_matching_pool_op, get_default_network_config
6
+
7
+ _VGG_CONFIGS = {
8
+ '16': {'features_per_stage': (64, 128, 256, 512, 512, 512), 'n_conv_per_stage': (2, 2, 2, 3, 3, 3),
9
+ 'strides': (1, 2, 2, 2, 2, 2)},
10
+ '19': {'features_per_stage': (64, 128, 256, 512, 512, 512), 'n_conv_per_stage': (2, 2, 3, 3, 4, 4),
11
+ 'strides': (1, 2, 2, 2, 2, 2)},
12
+ '16_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_conv_per_stage': (2, 3, 5, 5), 'strides': (1, 2, 2, 2)},
13
+ '19_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_conv_per_stage': (3, 4, 5, 6), 'strides': (1, 2, 2, 2)},
14
+ }
15
+
16
+ _VGG_OPS = {
17
+ 1: {'conv_op': nn.Conv1d, 'norm_op': nn.BatchNorm1d},
18
+ 2: {'conv_op': nn.Conv2d, 'norm_op': nn.BatchNorm2d},
19
+ 3: {'conv_op': nn.Conv3d, 'norm_op': nn.BatchNorm3d},
20
+ }
21
+
22
+
23
+ class VGG(nn.Module):
24
+ def __init__(self, n_classes: int, n_input_channel: int = 3, config='16', input_dimension=2):
25
+ """
26
+ This is not 1:1 VGG because it does not have the bloated fully connected layers at the end. Since these were
27
+ counted towards the XX layers as well, we increase the number of convolutional layers so that we have the
28
+ desired number of conv layers in total
29
+
30
+ We also use batchnorm
31
+ """
32
+ super().__init__()
33
+ cfg = _VGG_CONFIGS[config]
34
+ ops = get_default_network_config(dimension=input_dimension)
35
+ self.encoder = PlainConvEncoder(
36
+ n_input_channel, n_stages=len(cfg['features_per_stage']), features_per_stage=cfg['features_per_stage'],
37
+ conv_op=ops['conv_op'],
38
+ kernel_sizes=3, strides=cfg['strides'], n_conv_per_stage=cfg['n_conv_per_stage'], conv_bias=False,
39
+ norm_op=ops['norm_op'], norm_op_kwargs=None, dropout_op=None, dropout_op_kwargs=None, nonlin=nn.ReLU,
40
+ nonlin_kwargs={'inplace': True}, return_skips=False
41
+ )
42
+ self.gap = get_matching_pool_op(conv_op=ops['conv_op'], adaptive=True, pool_type='avg')(1)
43
+ self.classifier = nn.Linear(cfg['features_per_stage'][-1], n_classes, True)
44
+
45
+ def forward(self, x):
46
+ x = self.encoder(x)
47
+ x = self.gap(x).squeeze()
48
+ return self.classifier(x)
49
+
50
+ def compute_conv_feature_map_size(self, input_size):
51
+ return self.encoder.compute_conv_feature_map_size(input_size)
52
+
53
+
54
+ class VGG16(VGG):
55
+ def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
56
+ super().__init__(n_classes, n_input_channel, config='16', input_dimension=input_dimension)
57
+
58
+
59
+ class VGG19(VGG):
60
+ def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
61
+ super().__init__(n_classes, n_input_channel, config='19', input_dimension=input_dimension)
62
+
63
+
64
+ class VGG16_cifar(VGG):
65
+ def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
66
+ super().__init__(n_classes, n_input_channel, config='16_cifar', input_dimension=input_dimension)
67
+
68
+
69
+ class VGG19_cifar(VGG):
70
+ def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2):
71
+ super().__init__(n_classes, n_input_channel, config='19_cifar', input_dimension=input_dimension)
72
+
73
+
74
+ if __name__ == '__main__':
75
+ data = torch.rand((1, 3, 32, 32))
76
+
77
+ model = VGG19_cifar(10, 3)
78
+ import hiddenlayer as hl
79
+
80
+ g = hl.build_graph(model, data,
81
+ transforms=None)
82
+ g.save("network_architecture.pdf")
83
+ del g
84
+
85
+ print(model.compute_conv_feature_map_size((32, 32)))
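A brief usage sketch for the CIFAR-sized classifier variants defined above (not part of the commit); the class names come straight from this file, while the batch shape and class count are illustrative.

import torch
from dynamic_network_architectures.architectures.vgg import VGG16_cifar

model = VGG16_cifar(n_classes=10, n_input_channel=3, input_dimension=2)
model.eval()

x = torch.rand(4, 3, 32, 32)                      # CIFAR-10 sized batch
with torch.no_grad():
    logits = model(x)                             # (4, 10): GAP over the last stage, then the linear head
print(logits.shape)
print(model.compute_conv_feature_map_size((32, 32)))   # conv feature-map elements per sample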
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__init__.py ADDED
File without changes
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (272 Bytes). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/helper.cpython-310.pyc ADDED
Binary file (5.93 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/plain_conv_encoder.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/regularization.cpython-310.pyc ADDED
Binary file (4.39 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual.cpython-310.pyc ADDED
Binary file (14.2 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual_encoders.cpython-310.pyc ADDED
Binary file (6.39 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/simple_conv_blocks.cpython-310.pyc ADDED
Binary file (5.85 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/unet_decoder.cpython-310.pyc ADDED
Binary file (6.85 kB). View file
 
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/helper.py ADDED
@@ -0,0 +1,242 @@
1
+ from typing import Type
2
+ import numpy as np
3
+ import torch.nn
4
+ from torch import nn
5
+ from torch.nn.modules.batchnorm import _BatchNorm
6
+ from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
7
+ from torch.nn.modules.dropout import _DropoutNd
8
+ from torch.nn.modules.instancenorm import _InstanceNorm
9
+
10
+
11
+ def convert_dim_to_conv_op(dimension: int) -> Type[_ConvNd]:
12
+ """
13
+ :param dimension: 1, 2 or 3
14
+ :return: conv Class of corresponding dimension
15
+ """
16
+ if dimension == 1:
17
+ return nn.Conv1d
18
+ elif dimension == 2:
19
+ return nn.Conv2d
20
+ elif dimension == 3:
21
+ return nn.Conv3d
22
+ else:
23
+ raise ValueError("Unknown dimension. Only 1, 2 and 3 are supported")
24
+
25
+
26
+ def convert_conv_op_to_dim(conv_op: Type[_ConvNd]) -> int:
27
+ """
28
+ :param conv_op: conv class
29
+ :return: dimension: 1, 2 or 3
30
+ """
31
+ if conv_op == nn.Conv1d:
32
+ return 1
33
+ elif conv_op == nn.Conv2d:
34
+ return 2
35
+ elif conv_op == nn.Conv3d:
36
+ return 3
37
+ else:
38
+ raise ValueError("Unknown dimension. Only 1d 2d and 3d conv are supported. got %s" % str(conv_op))
39
+
40
+
41
+ def get_matching_pool_op(conv_op: Type[_ConvNd] = None,
42
+ dimension: int = None,
43
+ adaptive=False,
44
+ pool_type: str = 'avg') -> Type[torch.nn.Module]:
45
+ """
46
+ You MUST set EITHER conv_op OR dimension. Do not set both!
47
+ :param conv_op:
48
+ :param dimension:
49
+ :param adaptive:
50
+ :param pool_type: either 'avg' or 'max'
51
+ :return:
52
+ """
53
+ assert not ((conv_op is not None) and (dimension is not None)), \
54
+ "You MUST set EITHER conv_op OR dimension. Do not set both!"
55
+ assert pool_type in ['avg', 'max'], 'pool_type must be either avg or max'
56
+ if conv_op is not None:
57
+ dimension = convert_conv_op_to_dim(conv_op)
58
+ assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
59
+
60
+ if conv_op is not None:
61
+ dimension = convert_conv_op_to_dim(conv_op)
62
+
63
+ if dimension == 1:
64
+ if pool_type == 'avg':
65
+ if adaptive:
66
+ return nn.AdaptiveAvgPool1d
67
+ else:
68
+ return nn.AvgPool1d
69
+ elif pool_type == 'max':
70
+ if adaptive:
71
+ return nn.AdaptiveMaxPool1d
72
+ else:
73
+ return nn.MaxPool1d
74
+ elif dimension == 2:
75
+ if pool_type == 'avg':
76
+ if adaptive:
77
+ return nn.AdaptiveAvgPool2d
78
+ else:
79
+ return nn.AvgPool2d
80
+ elif pool_type == 'max':
81
+ if adaptive:
82
+ return nn.AdaptiveMaxPool2d
83
+ else:
84
+ return nn.MaxPool2d
85
+ elif dimension == 3:
86
+ if pool_type == 'avg':
87
+ if adaptive:
88
+ return nn.AdaptiveAvgPool3d
89
+ else:
90
+ return nn.AvgPool3d
91
+ elif pool_type == 'max':
92
+ if adaptive:
93
+ return nn.AdaptiveMaxPool3d
94
+ else:
95
+ return nn.MaxPool3d
96
+
97
+
98
+ def get_matching_instancenorm(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_InstanceNorm]:
99
+ """
100
+ You MUST set EITHER conv_op OR dimension. Do not set both!
101
+
102
+ :param conv_op:
103
+ :param dimension:
104
+ :return:
105
+ """
106
+ assert not ((conv_op is not None) and (dimension is not None)), \
107
+ "You MUST set EITHER conv_op OR dimension. Do not set both!"
108
+ if conv_op is not None:
109
+ dimension = convert_conv_op_to_dim(conv_op)
110
+ if dimension is not None:
111
+ assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
112
+ if dimension == 1:
113
+ return nn.InstanceNorm1d
114
+ elif dimension == 2:
115
+ return nn.InstanceNorm2d
116
+ elif dimension == 3:
117
+ return nn.InstanceNorm3d
118
+
119
+
120
+ def get_matching_convtransp(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_ConvTransposeNd]:
121
+ """
122
+ You MUST set EITHER conv_op OR dimension. Do not set both!
123
+
124
+ :param conv_op:
125
+ :param dimension:
126
+ :return:
127
+ """
128
+ assert not ((conv_op is not None) and (dimension is not None)), \
129
+ "You MUST set EITHER conv_op OR dimension. Do not set both!"
130
+ if conv_op is not None:
131
+ dimension = convert_conv_op_to_dim(conv_op)
132
+ assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
133
+ if dimension == 1:
134
+ return nn.ConvTranspose1d
135
+ elif dimension == 2:
136
+ return nn.ConvTranspose2d
137
+ elif dimension == 3:
138
+ return nn.ConvTranspose3d
139
+
140
+
141
+ def get_matching_batchnorm(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_BatchNorm]:
142
+ """
143
+ You MUST set EITHER conv_op OR dimension. Do not set both!
144
+
145
+ :param conv_op:
146
+ :param dimension:
147
+ :return:
148
+ """
149
+ assert not ((conv_op is not None) and (dimension is not None)), \
150
+ "You MUST set EITHER conv_op OR dimension. Do not set both!"
151
+ if conv_op is not None:
152
+ dimension = convert_conv_op_to_dim(conv_op)
153
+ assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
154
+ if dimension == 1:
155
+ return nn.BatchNorm1d
156
+ elif dimension == 2:
157
+ return nn.BatchNorm2d
158
+ elif dimension == 3:
159
+ return nn.BatchNorm3d
160
+
161
+
162
+ def get_matching_dropout(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_DropoutNd]:
163
+ """
164
+ You MUST set EITHER conv_op OR dimension. Do not set both!
165
+
166
+ :param conv_op:
167
+ :param dimension:
168
+ :return:
169
+ """
170
+ assert not ((conv_op is not None) and (dimension is not None)), \
171
+ "You MUST set EITHER conv_op OR dimension. Do not set both!"
172
+ assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3'
173
+ if dimension == 1:
174
+ return nn.Dropout
175
+ elif dimension == 2:
176
+ return nn.Dropout2d
177
+ elif dimension == 3:
178
+ return nn.Dropout3d
179
+
180
+
181
+ def maybe_convert_scalar_to_list(conv_op, scalar):
182
+ """
183
+ useful for converting, for example, kernel_size=3 to [3, 3, 3] in case of nn.Conv3d
184
+ :param conv_op:
185
+ :param scalar:
186
+ :return:
187
+ """
188
+ if not isinstance(scalar, (tuple, list, np.ndarray)):
189
+ if conv_op == nn.Conv2d:
190
+ return [scalar] * 2
191
+ elif conv_op == nn.Conv3d:
192
+ return [scalar] * 3
193
+ elif conv_op == nn.Conv1d:
194
+ return [scalar] * 1
195
+ else:
196
+ raise RuntimeError("Invalid conv op: %s" % str(conv_op))
197
+ else:
198
+ return scalar
199
+
200
+
201
+ def get_default_network_config(dimension: int = 2,
202
+ nonlin: str = "ReLU",
203
+ norm_type: str = "bn") -> dict:
204
+ """
205
+ Use this to get a standard configuration. A network configuration looks like this:
206
+
207
+ config = {'conv_op': torch.nn.modules.conv.Conv2d,
208
+ 'dropout_op': torch.nn.modules.dropout.Dropout2d,
209
+ 'norm_op': torch.nn.modules.batchnorm.BatchNorm2d,
210
+ 'norm_op_kwargs': {'eps': 1e-05, 'affine': True},
211
+ 'nonlin': torch.nn.modules.activation.ReLU,
212
+ 'nonlin_kwargs': {'inplace': True}}
213
+
214
+ There is no need to use get_default_network_config. You can create your own. Network configs are a convenient way of
215
+ setting dimensionality, normalization and nonlinearity.
216
+
217
+ :param dimension: integer denoting the dimension of the data. 1, 2 and 3 are accepted
218
+ :param nonlin: string (ReLU or LeakyReLU)
219
+ :param norm_type: string (bn=batch norm, in=instance norm)
220
+ torch.nn.Module
221
+ :return: dict
222
+ """
223
+ config = {}
224
+ config['conv_op'] = convert_dim_to_conv_op(dimension)
225
+ config['dropout_op'] = get_matching_dropout(dimension=dimension)
226
+ if norm_type == "bn":
227
+ config['norm_op'] = get_matching_batchnorm(dimension=dimension)
228
+ elif norm_type == "in":
229
+ config['norm_op'] = get_matching_instancenorm(dimension=dimension)
230
+
231
+ config['norm_op_kwargs'] = None # this will use defaults
232
+
233
+ if nonlin == "LeakyReLU":
234
+ config['nonlin'] = nn.LeakyReLU
235
+ config['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
236
+ elif nonlin == "ReLU":
237
+ config['nonlin'] = nn.ReLU
238
+ config['nonlin_kwargs'] = {'inplace': True}
239
+ else:
240
+ raise NotImplementedError('Unknown nonlin %s. Only "LeakyReLU" and "ReLU" are supported for now' % nonlin)
241
+
242
+ return config
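A short sketch of how the helpers above are typically combined (not part of the commit); the argument values are illustrative.

from torch import nn
from dynamic_network_architectures.building_blocks.helper import (
    convert_dim_to_conv_op, get_default_network_config, maybe_convert_scalar_to_list)

# dimension-matched ops bundled into one config dict
cfg = get_default_network_config(dimension=3, nonlin="LeakyReLU", norm_type="in")
print(cfg['conv_op'])        # nn.Conv3d
print(cfg['norm_op'])        # nn.InstanceNorm3d
print(cfg['nonlin_kwargs'])  # {'negative_slope': 0.01, 'inplace': True}

# broadcast a scalar kernel size / stride to the conv op's dimensionality
conv_op = convert_dim_to_conv_op(2)              # nn.Conv2d
print(maybe_convert_scalar_to_list(conv_op, 3))  # [3, 3]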
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/plain_conv_encoder.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ from typing import Union, Type, List, Tuple
5
+
6
+ from torch.nn.modules.conv import _ConvNd
7
+ from torch.nn.modules.dropout import _DropoutNd
8
+ from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
9
+ from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
10
+
11
+
12
+ class PlainConvEncoder(nn.Module):
13
+ def __init__(self,
14
+ input_channels: int,
15
+ n_stages: int,
16
+ features_per_stage: Union[int, List[int], Tuple[int, ...]],
17
+ conv_op: Type[_ConvNd],
18
+ kernel_sizes: Union[int, List[int], Tuple[int, ...]],
19
+ strides: Union[int, List[int], Tuple[int, ...]],
20
+ n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
21
+ conv_bias: bool = False,
22
+ norm_op: Union[None, Type[nn.Module]] = None,
23
+ norm_op_kwargs: dict = None,
24
+ dropout_op: Union[None, Type[_DropoutNd]] = None,
25
+ dropout_op_kwargs: dict = None,
26
+ nonlin: Union[None, Type[torch.nn.Module]] = None,
27
+ nonlin_kwargs: dict = None,
28
+ return_skips: bool = False,
29
+ nonlin_first: bool = False,
30
+ pool: str = 'conv'
31
+ ):
32
+
33
+ super().__init__()
34
+ if isinstance(kernel_sizes, int):
35
+ kernel_sizes = [kernel_sizes] * n_stages
36
+ if isinstance(features_per_stage, int):
37
+ features_per_stage = [features_per_stage] * n_stages
38
+ if isinstance(n_conv_per_stage, int):
39
+ n_conv_per_stage = [n_conv_per_stage] * n_stages
40
+ if isinstance(strides, int):
41
+ strides = [strides] * n_stages
42
+ assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
43
+ assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
44
+ assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
45
+ assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
46
+ "Important: first entry is recommended to be 1, else we run strided conv drectly on the input"
47
+
48
+ stages = []
49
+ for s in range(n_stages):
50
+ stage_modules = []
51
+ if pool == 'max' or pool == 'avg':
52
+ if (isinstance(strides[s], int) and strides[s] != 1) or \
53
+ isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
54
+ stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
55
+ conv_stride = 1
56
+ elif pool == 'conv':
57
+ conv_stride = strides[s]
58
+ else:
59
+ raise RuntimeError()
60
+ stage_modules.append(StackedConvBlocks(
61
+ n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
62
+ conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
63
+ ))
64
+ stages.append(nn.Sequential(*stage_modules))
65
+ input_channels = features_per_stage[s]
66
+
67
+ self.stages = nn.Sequential(*stages)
68
+ self.output_channels = features_per_stage
69
+ self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
70
+ self.return_skips = return_skips
71
+
72
+ # we store some things that a potential decoder needs
73
+ self.conv_op = conv_op
74
+ self.norm_op = norm_op
75
+ self.norm_op_kwargs = norm_op_kwargs
76
+ self.nonlin = nonlin
77
+ self.nonlin_kwargs = nonlin_kwargs
78
+ self.dropout_op = dropout_op
79
+ self.dropout_op_kwargs = dropout_op_kwargs
80
+ self.conv_bias = conv_bias
81
+ self.kernel_sizes = kernel_sizes
82
+
83
+ def forward(self, x):
84
+ ret = []
85
+ for s in self.stages:
86
+ x = s(x)
87
+ ret.append(x)
88
+ if self.return_skips:
89
+ return ret
90
+ else:
91
+ return ret[-1]
92
+
93
+ def compute_conv_feature_map_size(self, input_size):
94
+ output = np.int64(0)
95
+ for s in range(len(self.stages)):
96
+ if isinstance(self.stages[s], nn.Sequential):
97
+ for sq in self.stages[s]:
98
+ if hasattr(sq, 'compute_conv_feature_map_size'):
99
+ output += self.stages[s][-1].compute_conv_feature_map_size(input_size)
100
+ else:
101
+ output += self.stages[s].compute_conv_feature_map_size(input_size)
102
+ input_size = [i // j for i, j in zip(input_size, self.strides[s])]
103
+ return output
104
+
105
+
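A minimal sketch of the encoder on its own, returning skip features the way a decoder would consume them (not part of the commit); the stage layout is illustrative.

import torch
from torch import nn
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder

enc = PlainConvEncoder(
    input_channels=1, n_stages=4, features_per_stage=(16, 32, 64, 128),
    conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2), n_conv_per_stage=2,
    norm_op=nn.InstanceNorm2d, norm_op_kwargs={'affine': True},
    nonlin=nn.ReLU, nonlin_kwargs={'inplace': True},
    return_skips=True)

x = torch.rand(2, 1, 128, 128)
skips = enc(x)                       # one feature map per stage, coarsest last
for s in skips:
    print(s.shape)                   # (2, 16, 128, 128) ... (2, 128, 16, 16)
print(enc.compute_conv_feature_map_size((128, 128)))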
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/regularization.py ADDED
@@ -0,0 +1,86 @@
1
+ from torch import nn
2
+
3
+
4
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
5
+ """
6
+ This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py).
7
+
8
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
9
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
10
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
11
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
12
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
13
+ 'survival rate' as the argument.
14
+ """
15
+ if drop_prob == 0. or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0 and scale_by_keep:
21
+ random_tensor.div_(keep_prob)
22
+ return x * random_tensor
23
+
24
+
25
+ class DropPath(nn.Module):
26
+ """
27
+ This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py).
28
+
29
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
30
+ """
31
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
32
+ super(DropPath, self).__init__()
33
+ self.drop_prob = drop_prob
34
+ self.scale_by_keep = scale_by_keep
35
+
36
+ def forward(self, x):
37
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
38
+
39
+
40
+ class SqueezeExcite(nn.Module):
41
+ """
42
+ This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/squeeze_excite.py)
43
+ and slightly modified so that the convolution type can be adapted.
44
+
45
+ SE Module as defined in original SE-Nets with a few additions
46
+ Additions include:
47
+ * divisor can be specified to keep channels % div == 0 (default: 8)
48
+ * reduction channels can be specified directly by arg (if rd_channels is set)
49
+ * reduction channels can be specified by float rd_ratio (default: 1/16)
50
+ * global max pooling can be added to the squeeze aggregation
51
+ * customizable activation, normalization, and gate layer
52
+ """
53
+ def __init__(
54
+ self, channels, conv_op, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
55
+ act_layer=nn.ReLU, norm_layer=None, gate_layer=nn.Sigmoid):
56
+ super(SqueezeExcite, self).__init__()
57
+ self.add_maxpool = add_maxpool
58
+ if not rd_channels:
59
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
60
+ self.fc1 = conv_op(channels, rd_channels, kernel_size=1, bias=True)
61
+ self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
62
+ self.act = act_layer(inplace=True)
63
+ self.fc2 = conv_op(rd_channels, channels, kernel_size=1, bias=True)
64
+ self.gate = gate_layer()
65
+
66
+ def forward(self, x):
67
+ x_se = x.mean((2, 3), keepdim=True)
68
+ if self.add_maxpool:
69
+ # experimental codepath, may remove or change
70
+ x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
71
+ x_se = self.fc1(x_se)
72
+ x_se = self.act(self.bn(x_se))
73
+ x_se = self.fc2(x_se)
74
+ return x * self.gate(x_se)
75
+
76
+
77
+ def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
78
+ """
79
+ This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/b7cb8d0337b3e7b50516849805ddb9be5fc11644/timm/models/layers/helpers.py#L25)
80
+ """
81
+ min_value = min_value or divisor
82
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
83
+ # Make sure that round down does not go down by more than 10%.
84
+ if new_v < round_limit * v:
85
+ new_v += divisor
86
+ return new_v
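A quick sketch exercising the two regularization blocks above on a 2D feature map (not part of the commit). Note that SqueezeExcite as written pools over dims (2, 3), so it targets 2D conv ops; the tensor shapes are illustrative.

import torch
from torch import nn
from dynamic_network_architectures.building_blocks.regularization import DropPath, SqueezeExcite

x = torch.rand(2, 32, 16, 16)

drop = DropPath(drop_prob=0.2)       # stochastic depth: zeroes whole samples at random during training
drop.train()
print(drop(x).shape)                 # (2, 32, 16, 16)

se = SqueezeExcite(channels=32, conv_op=nn.Conv2d)   # rd_ratio defaults to 1/16
print(se(x).shape)                   # (2, 32, 16, 16), channels rescaled by learned gates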
model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/residual.py ADDED
@@ -0,0 +1,371 @@
1
+ from typing import Tuple, List, Union, Type
2
+ import torch.nn
3
+ from torch import nn
4
+ from torch.nn.modules.conv import _ConvNd
5
+ from torch.nn.modules.dropout import _DropoutNd
6
+
7
+ from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
8
+ from dynamic_network_architectures.building_blocks.simple_conv_blocks import ConvDropoutNormReLU
9
+ from dynamic_network_architectures.building_blocks.regularization import DropPath, SqueezeExcite
10
+ import numpy as np
11
+
12
+
13
+ class BasicBlockD(nn.Module):
14
+ def __init__(self,
15
+ conv_op: Type[_ConvNd],
16
+ input_channels: int,
17
+ output_channels: int,
18
+ kernel_size: Union[int, List[int], Tuple[int, ...]],
19
+ stride: Union[int, List[int], Tuple[int, ...]],
20
+ conv_bias: bool = False,
21
+ norm_op: Union[None, Type[nn.Module]] = None,
22
+ norm_op_kwargs: dict = None,
23
+ dropout_op: Union[None, Type[_DropoutNd]] = None,
24
+ dropout_op_kwargs: dict = None,
25
+ nonlin: Union[None, Type[torch.nn.Module]] = None,
26
+ nonlin_kwargs: dict = None,
27
+ stochastic_depth_p: float = 0.0,
28
+ squeeze_excitation: bool = False,
29
+ squeeze_excitation_reduction_ratio: float = 1. / 16,
30
+ # todo wideresnet?
31
+ ):
32
+ """
33
+ This implementation follows ResNet-D:
34
+
35
+ He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks."
36
+ Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
37
+
38
+ The skip has an avgpool (if needed) followed by 1x1 conv instead of just a strided 1x1 conv
39
+
40
+ :param conv_op:
41
+ :param input_channels:
42
+ :param output_channels:
43
+ :param kernel_size: refers only to convs in feature extraction path, not to 1x1x1 conv in skip
44
+ :param stride: only applies to first conv (and skip). Second conv always has stride 1
45
+ :param conv_bias:
46
+ :param norm_op:
47
+ :param norm_op_kwargs:
48
+ :param dropout_op: only the first conv can have dropout. The second never has
49
+ :param dropout_op_kwargs:
50
+ :param nonlin:
51
+ :param nonlin_kwargs:
52
+ :param stochastic_depth_p:
53
+ :param squeeze_excitation:
54
+ :param squeeze_excitation_reduction_ratio:
55
+ """
56
+ super().__init__()
57
+ self.input_channels = input_channels
58
+ self.output_channels = output_channels
59
+ stride = maybe_convert_scalar_to_list(conv_op, stride)
60
+ self.stride = stride
61
+
62
+ kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
63
+
64
+ if norm_op_kwargs is None:
65
+ norm_op_kwargs = {}
66
+ if nonlin_kwargs is None:
67
+ nonlin_kwargs = {}
68
+
69
+ self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, output_channels, kernel_size, stride, conv_bias,
70
+ norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
71
+ self.conv2 = ConvDropoutNormReLU(conv_op, output_channels, output_channels, kernel_size, 1, conv_bias, norm_op,
72
+ norm_op_kwargs, None, None, None, None)
73
+
74
+ self.nonlin2 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x
75
+
76
+ # Stochastic Depth
77
+ self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True
78
+ if self.apply_stochastic_depth:
79
+ self.drop_path = DropPath(drop_prob=stochastic_depth_p)
80
+
81
+ # Squeeze Excitation
82
+ self.apply_se = squeeze_excitation
83
+ if self.apply_se:
84
+ self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op,
85
+ rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8)
86
+
87
+ has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride])
88
+ requires_projection = (input_channels != output_channels)
89
+
90
+ if has_stride or requires_projection:
91
+ ops = []
92
+ if has_stride:
93
+ ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride))
94
+ if requires_projection:
95
+ ops.append(
96
+ ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False, norm_op,
97
+ norm_op_kwargs, None, None, None, None
98
+ )
99
+ )
100
+ self.skip = nn.Sequential(*ops)
101
+ else:
102
+ self.skip = lambda x: x
103
+
104
+ def forward(self, x):
105
+ residual = self.skip(x)
106
+ out = self.conv2(self.conv1(x))
107
+ if self.apply_stochastic_depth:
108
+ out = self.drop_path(out)
109
+ if self.apply_se:
110
+ out = self.squeeze_excitation(out)
111
+ out += residual
112
+ return self.nonlin2(out)
113
+
114
+ def compute_conv_feature_map_size(self, input_size):
115
+ assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
116
+ "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
117
+ "Give input_size=(x, y(, z))!"
118
+ size_after_stride = [i // j for i, j in zip(input_size, self.stride)]
119
+ # conv1
120
+ output_size_conv1 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
121
+ # conv2
122
+ output_size_conv2 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
123
+ # skip conv (if applicable)
124
+ if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]):
125
+ assert isinstance(self.skip, nn.Sequential)
126
+ output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
127
+ else:
128
+ assert not isinstance(self.skip, nn.Sequential)
129
+ output_size_skip = 0
130
+ return output_size_conv1 + output_size_conv2 + output_size_skip
131
+
132
+
133
+ class BottleneckD(nn.Module):
134
+ def __init__(self,
135
+ conv_op: Type[_ConvNd],
136
+ input_channels: int,
137
+ bottleneck_channels: int,
138
+ output_channels: int,
139
+ kernel_size: Union[int, List[int], Tuple[int, ...]],
140
+ stride: Union[int, List[int], Tuple[int, ...]],
141
+ conv_bias: bool = False,
142
+ norm_op: Union[None, Type[nn.Module]] = None,
143
+ norm_op_kwargs: dict = None,
144
+ dropout_op: Union[None, Type[_DropoutNd]] = None,
145
+ dropout_op_kwargs: dict = None,
146
+ nonlin: Union[None, Type[torch.nn.Module]] = None,
147
+ nonlin_kwargs: dict = None,
148
+ stochastic_depth_p: float = 0.0,
149
+ squeeze_excitation: bool = False,
150
+ squeeze_excitation_reduction_ratio: float = 1. / 16
151
+ ):
152
+ """
153
+ This implementation follows ResNet-D:
154
+
155
+ He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks."
156
+ Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
157
+
158
+ The stride sits in the 3x3 conv instead of the 1x1 conv!
159
+ The skip has an avgpool (if needed) followed by 1x1 conv instead of just a strided 1x1 conv
160
+
161
+ :param conv_op:
162
+ :param input_channels:
163
+ :param output_channels:
164
+ :param kernel_size: only affects the conv in the middle (typically 3x3). The other convs remain 1x1
165
+ :param stride: only applies to the conv in the middle (and skip). Note that this deviates from the canonical
166
+ ResNet implementation where the stride is applied to the first 1x1 conv. (This implementation follows ResNet-D)
167
+ :param conv_bias:
168
+ :param norm_op:
169
+ :param norm_op_kwargs:
170
+ :param dropout_op: only the second (kernel_size) conv can have dropout. The first and last conv (1x1(x1)) never have it
171
+ :param dropout_op_kwargs:
172
+ :param nonlin:
173
+ :param nonlin_kwargs:
174
+ :param stochastic_depth_p:
175
+ :param squeeze_excitation:
176
+ :param squeeze_excitation_reduction_ratio:
177
+ """
178
+ super().__init__()
179
+ self.input_channels = input_channels
180
+ self.output_channels = output_channels
181
+ self.bottleneck_channels = bottleneck_channels
182
+ stride = maybe_convert_scalar_to_list(conv_op, stride)
183
+ self.stride = stride
184
+
185
+ kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
186
+ if norm_op_kwargs is None:
187
+ norm_op_kwargs = {}
188
+ if nonlin_kwargs is None:
189
+ nonlin_kwargs = {}
190
+
191
+ self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, bottleneck_channels, 1, 1, conv_bias,
192
+ norm_op, norm_op_kwargs, None, None, nonlin, nonlin_kwargs)
193
+ self.conv2 = ConvDropoutNormReLU(conv_op, bottleneck_channels, bottleneck_channels, kernel_size, stride,
194
+ conv_bias,
195
+ norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
196
+ self.conv3 = ConvDropoutNormReLU(conv_op, bottleneck_channels, output_channels, 1, 1, conv_bias, norm_op,
197
+ norm_op_kwargs, None, None, None, None)
198
+
199
+ self.nonlin3 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x
200
+
201
+ # Stochastic Depth
202
+ self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True
203
+ if self.apply_stochastic_depth:
204
+ self.drop_path = DropPath(drop_prob=stochastic_depth_p)
205
+
206
+ # Squeeze Excitation
207
+ self.apply_se = squeeze_excitation
208
+ if self.apply_se:
209
+ self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op,
210
+ rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8)
211
+
212
+ has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride])
213
+ requires_projection = (input_channels != output_channels)
214
+
215
+ if has_stride or requires_projection:
216
+ ops = []
217
+ if has_stride:
218
+ ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride))
219
+ if requires_projection:
220
+ ops.append(
221
+ ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False,
222
+ norm_op, norm_op_kwargs, None, None, None, None
223
+ )
224
+ )
225
+ self.skip = nn.Sequential(*ops)
226
+ else:
227
+ self.skip = lambda x: x
228
+
229
+ def forward(self, x):
230
+ residual = self.skip(x)
231
+ out = self.conv3(self.conv2(self.conv1(x)))
232
+ if self.apply_stochastic_depth:
233
+ out = self.drop_path(out)
234
+ if self.apply_se:
235
+ out = self.squeeze_excitation(out)
236
+ out += residual
237
+ return self.nonlin3(out)
238
+
239
+ def compute_conv_feature_map_size(self, input_size):
240
+ assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
241
+ "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
242
+ "Give input_size=(x, y(, z))!"
243
+ size_after_stride = [i // j for i, j in zip(input_size, self.stride)]
244
+ # conv1
245
+ output_size_conv1 = np.prod([self.bottleneck_channels, *input_size], dtype=np.int64)
246
+ # conv2
247
+ output_size_conv2 = np.prod([self.bottleneck_channels, *size_after_stride], dtype=np.int64)
248
+ # conv3
249
+ output_size_conv3 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
250
+ # skip conv (if applicable)
251
+ if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]):
252
+ assert isinstance(self.skip, nn.Sequential)
253
+ output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
254
+ else:
255
+ assert not isinstance(self.skip, nn.Sequential)
256
+ output_size_skip = 0
257
+ return output_size_conv1 + output_size_conv2 + output_size_conv3 + output_size_skip
258
+
259
+
260
+ class StackedResidualBlocks(nn.Module):
261
+ def __init__(self,
262
+ n_blocks: int,
263
+ conv_op: Type[_ConvNd],
264
+ input_channels: int,
265
+ output_channels: Union[int, List[int], Tuple[int, ...]],
266
+ kernel_size: Union[int, List[int], Tuple[int, ...]],
267
+ initial_stride: Union[int, List[int], Tuple[int, ...]],
268
+ conv_bias: bool = False,
269
+ norm_op: Union[None, Type[nn.Module]] = None,
270
+ norm_op_kwargs: dict = None,
271
+ dropout_op: Union[None, Type[_DropoutNd]] = None,
272
+ dropout_op_kwargs: dict = None,
273
+ nonlin: Union[None, Type[torch.nn.Module]] = None,
274
+ nonlin_kwargs: dict = None,
275
+ block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
276
+ bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
277
+ stochastic_depth_p: float = 0.0,
278
+ squeeze_excitation: bool = False,
279
+ squeeze_excitation_reduction_ratio: float = 1. / 16
280
+ ):
281
+ """
282
+ Stack multiple instances of block.
283
+
284
+ :param n_blocks: number of residual blocks
285
+ :param conv_op: nn.ConvNd class
286
+ :param input_channels: only relevant for the first block in the sequence. This is the input number of features.
287
+ After the first block, the number of features in the main path to which the residuals are added is output_channels
288
+ :param output_channels: number of features in the main path to which the residuals are added (and also the
289
+ number of features of the output)
290
+ :param kernel_size: kernel size for all nxn (n!=1) convolutions. Default: 3x3
291
+ :param initial_stride: only affects the first block. All subsequent blocks have stride 1
292
+ :param conv_bias: usually False
293
+ :param norm_op: nn.BatchNormNd, InstanceNormNd etc
294
+ :param norm_op_kwargs: dictionary of kwargs. Leave empty ({}) for defaults
295
+ :param dropout_op: nn.DropoutNd, can be None for no dropout
296
+ :param dropout_op_kwargs:
297
+ :param nonlin:
298
+ :param nonlin_kwargs:
299
+ :param block: BasicBlockD or BottleneckD
300
+ :param bottleneck_channels: if block is BottleneckD then we need to know the number of bottleneck features.
301
+ Bottleneck will use first 1x1 conv to reduce input to bottleneck features, then run the nxn (see kernel_size)
302
+ conv on that (bottleneck -> bottleneck). Finally the output will be projected back to output_channels
303
+ (bottleneck -> output_channels) with the final 1x1 conv
304
+ :param stochastic_depth_p: probability of applying stochastic depth in residual blocks
305
+ :param squeeze_excitation: whether to apply squeeze and excitation or not
306
+ :param squeeze_excitation_reduction_ratio: ratio by how much squeeze and excitation should reduce channels
307
+ respective to number of out channels of respective block
308
+ """
309
+ super().__init__()
310
+ assert n_blocks > 0, 'n_blocks must be > 0'
311
+ assert block in [BasicBlockD, BottleneckD], 'block must be BasicBlockD or BottleneckD'
312
+ if not isinstance(output_channels, (tuple, list)):
313
+ output_channels = [output_channels] * n_blocks
314
+ if not isinstance(bottleneck_channels, (tuple, list)):
315
+ bottleneck_channels = [bottleneck_channels] * n_blocks
316
+
317
+ if block == BasicBlockD:
318
+ blocks = nn.Sequential(
319
+ block(conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias,
320
+ norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, stochastic_depth_p,
321
+ squeeze_excitation, squeeze_excitation_reduction_ratio),
322
+ *[block(conv_op, output_channels[n - 1], output_channels[n], kernel_size, 1, conv_bias, norm_op,
323
+ norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, stochastic_depth_p,
324
+ squeeze_excitation, squeeze_excitation_reduction_ratio) for n in range(1, n_blocks)]
325
+ )
326
+ else:
327
+ blocks = nn.Sequential(
328
+ block(conv_op, input_channels, bottleneck_channels[0], output_channels[0], kernel_size,
329
+ initial_stride, conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
330
+ nonlin, nonlin_kwargs, stochastic_depth_p, squeeze_excitation, squeeze_excitation_reduction_ratio),
331
+ *[block(conv_op, output_channels[n - 1], bottleneck_channels[n], output_channels[n], kernel_size,
332
+ 1, conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
333
+ nonlin, nonlin_kwargs, stochastic_depth_p, squeeze_excitation,
334
+ squeeze_excitation_reduction_ratio) for n in range(1, n_blocks)]
335
+ )
336
+ self.blocks = blocks
337
+ self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)
338
+ self.output_channels = output_channels[-1]
339
+
340
+ def forward(self, x):
341
+ return self.blocks(x)
342
+
343
+ def compute_conv_feature_map_size(self, input_size):
344
+ assert len(input_size) == len(self.initial_stride), "just give the image size without color/feature channels or " \
345
+ "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
346
+ "Give input_size=(x, y(, z))!"
347
+ output = self.blocks[0].compute_conv_feature_map_size(input_size)
348
+ size_after_stride = [i // j for i, j in zip(input_size, self.initial_stride)]
349
+ for b in self.blocks[1:]:
350
+ output += b.compute_conv_feature_map_size(size_after_stride)
351
+ return output
352
+
353
+
354
+ if __name__ == '__main__':
355
+ data = torch.rand((1, 3, 40, 32))
356
+
357
+ stx = StackedResidualBlocks(2, nn.Conv2d, 24, (16, 16), (3, 3), (1, 2),
358
+ norm_op=nn.BatchNorm2d, nonlin=nn.ReLU, nonlin_kwargs={'inplace': True},
359
+ block=BottleneckD, bottleneck_channels=3)
360
+ model = nn.Sequential(ConvDropoutNormReLU(nn.Conv2d,
361
+ 3, 24, 3, 1, True, nn.BatchNorm2d, {}, None, None, nn.LeakyReLU,
362
+ {'inplace': True}),
363
+ stx)
364
+ import hiddenlayer as hl
365
+
366
+ g = hl.build_graph(model, data,
367
+ transforms=None)
368
+ g.save("network_architecture.pdf")
369
+ del g
370
+
371
+ print(stx.compute_conv_feature_map_size((40, 32)))
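A complementary sketch (not part of the commit) that runs a BasicBlockD stack without the hiddenlayer dependency used in the __main__ block above; all values are illustrative.

import torch
from torch import nn
from dynamic_network_architectures.building_blocks.residual import StackedResidualBlocks, BasicBlockD

stage = StackedResidualBlocks(
    n_blocks=3, conv_op=nn.Conv2d, input_channels=16, output_channels=32,
    kernel_size=3, initial_stride=2, norm_op=nn.BatchNorm2d,
    nonlin=nn.ReLU, nonlin_kwargs={'inplace': True}, block=BasicBlockD)

x = torch.rand(2, 16, 64, 64)
print(stage(x).shape)                                 # (2, 32, 32, 32): stride 2 only in the first block
print(stage.compute_conv_feature_map_size((64, 64)))  # feature-map elements per sample, for memory estimates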