VBoussot commited on
Commit
49b57ea
·
1 Parent(s): c1770b3

Add body CT

Browse files
body/Evaluation.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Evaluator:
2
+ metrics:
3
+ Output:
4
+ targets_criterions:
5
+ Reference;Mask:
6
+ criterions_loader:
7
+ DiceSaveMap:
8
+ labels: None
9
+ dataset: ./Evaluations/TotalSegmentator/Output:mha
10
+ group: Seg_MAE_map
11
+ Dataset:
12
+ groups_src:
13
+ Mask_0:
14
+ groups_dest:
15
+ Mask:
16
+ transforms: None
17
+ Volume_0:
18
+ groups_dest:
19
+ Output:
20
+ transforms: None
21
+ Reference_0:
22
+ groups_dest:
23
+ Reference:
24
+ transforms: None
25
+ subset: None
26
+ dataset_filenames:
27
+ - ./Dataset:a:mha
28
+ validation: None
29
+ train_name: TotalSegmentator
body/M299.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fb6a978453e6f8d2cd8ec42500f7946b53648db8043c6f2504c042910a366c5
3
+ size 124804837
body/Prediction.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Predictor:
2
+ Model:
3
+ classpath: model:Unet_TS_CT
4
+ Unet_TS_CT:
5
+ outputs_criterions: None
6
+ channels:
7
+ - 1
8
+ - 32
9
+ - 64
10
+ - 128
11
+ - 256
12
+ - 320
13
+ - 320
14
+ Dataset:
15
+ groups_src:
16
+ Volume_0:
17
+ groups_dest:
18
+ Volume:
19
+ transforms:
20
+ TensorCast:
21
+ dtype: float32
22
+ inverse: false
23
+ Canonical:
24
+ inverse: true
25
+ ResampleToResolution:
26
+ spacing:
27
+ - 1.5
28
+ - 1.5
29
+ - 1.5
30
+ inverse: true
31
+ Padding:
32
+ padding:
33
+ - 32
34
+ - 32
35
+ - 32
36
+ - 32
37
+ - 32
38
+ - 32
39
+ mode: constant
40
+ inverse: true
41
+ patch_transforms: None
42
+ is_input: true
43
+ augmentations: None
44
+ Patch:
45
+ patch_size:
46
+ - 96
47
+ - 128
48
+ - 160
49
+ overlap: 32
50
+ mask: None
51
+ pad_value: 0
52
+ extend_slice: 0
53
+ subset: None
54
+ filter: None
55
+ dataset_filenames:
56
+ - ./Dataset/:mha
57
+ use_cache: false
58
+ batch_size: 1
59
+ outputs_dataset:
60
+ Head:Softmax:
61
+ OutputDataset:
62
+ name_class: OutSameAsGroupDataset
63
+ before_reduction_transforms:
64
+ Softmax:
65
+ dim: 0
66
+ Argmax:
67
+ dim: 0
68
+ TensorCast:
69
+ dtype: uint8
70
+ inverse: true
71
+ after_reduction_transforms:
72
+ Sum:
73
+ dim: 0
74
+ final_transforms: None
75
+ dataset_filename: Output:mha
76
+ group: Output
77
+ same_as_group: Volume_0:Volume
78
+ patch_combine: Cosinus
79
+ inverse_transform: true
80
+ reduction: Concat
81
+ train_name: TotalSegmentator
82
+ manual_seed: 32
83
+ gpu_checkpoints: None
84
+ images_log: None
85
+ combine: Concat
86
+ autocast: true
87
+ data_log: None
body/app.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "display_name": "Segmentation: Body CT",
3
+ "short_description": "<b>Description:</b><br>KonfAI adaptation of the <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a> body CT model for fast segmentation into <b>body trunk</b> and <b>extremities</b> at <b>1.5 mm isotropic</b> resolution.<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
4
+ "description": "<b>Description:</b><br>This model is a <b>KonfAI</b> adaptation of the <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a> body CT segmentation model.<br><br><b>Capabilities:</b><br>&bull; Segmentation of <b>2 body classes</b>: body_trunc and body_extremities<br>&bull; CT input resampled to <b>1.5 mm isotropic</b><br>&bull; Lightweight body mask generation for preprocessing and downstream workflows<br><br><b>Training data:</b><br>Trained on <b>1559 CT scans</b>, with labels for the body trunk and extremities<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
5
+ "tta": 0,
6
+ "mc_dropout": 0,
7
+ "models": ["M299.pt"],
8
+ "inputs": {
9
+ "Volume": {
10
+ "display_name": "Input Volume",
11
+ "volume_type": "VOLUME",
12
+ "required": true
13
+ }
14
+ },
15
+ "outputs": {
16
+ "Segmentation": {
17
+ "display_name": "Segmentation",
18
+ "volume_type": "SEGMENTATION",
19
+ "required": true
20
+ }
21
+ },
22
+ "inputs_evaluations": {
23
+ "Image": {
24
+ "Evaluation.yml": {
25
+ "Segmentation": {
26
+ "display_name": "Output Segmentation",
27
+ "volume_type": "VOLUME",
28
+ "required": true
29
+ },
30
+ "GT_Segmentation": {
31
+ "display_name": "GT Segmentation",
32
+ "volume_type": "VOLUME",
33
+ "required": true
34
+ }
35
+ }
36
+ }
37
+ },
38
+ "terminology": {
39
+ "1": { "name": "body_trunc", "color": "#4B79EA" },
40
+ "2": { "name": "body_extremities", "color": "#A8EA4B" }
41
+ }
42
+ }
body/model.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from konfai.network import network, blocks
3
+
4
+ class ConvBlock(network.ModuleArgsDict):
5
+ def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
6
+ super().__init__()
7
+ self.add_module("Conv_0", torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=True))
8
+ self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
9
+ self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
10
+ self.add_module("Conv_1", torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True))
11
+ self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
12
+ self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
13
+
14
+ class UNetHead(network.ModuleArgsDict):
15
+ def __init__(self, in_channels: int, nb_class: int) -> None:
16
+ super().__init__()
17
+ self.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
18
+ self.add_module("Softmax", torch.nn.Softmax(dim=1))
19
+
20
+ class UNetBlock(network.ModuleArgsDict):
21
+
22
+ def __init__(self, channels, i : int = 0) -> None:
23
+ super().__init__()
24
+ self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= 2 if i>0 else 1))
25
+
26
+ if len(channels) > 2:
27
+ self.add_module("UNetBlock", UNetBlock(channels[1:], i+1))
28
+ self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
29
+
30
+ if i > 0:
31
+ self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = 2, stride = 2, padding = 0))
32
+ self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
33
+
34
+ class ClipAndNormalize(torch.nn.Module):
35
+
36
+ def __init__(self) -> None:
37
+ super().__init__()
38
+ self.register_buffer("clip_min", torch.empty(1))
39
+ self.register_buffer("clip_max", torch.empty(1))
40
+ self.register_buffer("mean", torch.empty(1))
41
+ self.register_buffer("std", torch.empty(1))
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ x = torch.clamp(x, self.clip_min, self.clip_max)
45
+ return (x - self.mean) / (self.std)
46
+
47
+ class Unet_TS_CT(network.Network):
48
+
49
+ def __init__(self,
50
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
51
+ schedulers: dict[str, network.LRSchedulersLoader] = {
52
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
53
+ },
54
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
55
+ channels: list[int] = [1, 32, 64, 128, 320, 320]) -> None:
56
+ super().__init__(
57
+ in_channels=channels[0],
58
+ optimizer=optimizer,
59
+ schedulers=schedulers,
60
+ outputs_criterions=outputs_criterions,
61
+ patch=None,
62
+ dim=3,
63
+ )
64
+ self.add_module("ClipAndNormalize", ClipAndNormalize())
65
+ self.add_module("UNetBlock", UNetBlock(channels))
66
+ self.add_module("Head", UNetHead(channels[1], 118))
67
+
68
+ def load(
69
+ self,
70
+ state_dict: dict[str, dict[str, torch.Tensor] | int],
71
+ init: bool = True,
72
+ ema: bool = False,
73
+ ):
74
+ nb_class, in_channels = state_dict["Model"]["Unet_TS_CT"]["Head.Conv.weight"].shape[:2]
75
+ self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
76
+ super().load(state_dict, init, ema)
body/requirements.txt ADDED
File without changes
build.py CHANGED
@@ -94,15 +94,17 @@ models = {
94
  "M294" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
95
  "M295" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
96
  "M297" : (UnetCPP_2_CT(), "total-3mm", 2, False, url+"v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
97
- #"M298" : (UnetCPP_2_CT(), 2, False, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
98
- #"M730" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
99
- #"M731" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
100
- #"M732" : (UnetCPP_2_MR(), False, 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
101
- #"M733" : (UnetCPP_3_MR(), False, 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
102
  "M850" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
103
  "M851" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
104
  "M852" : (UnetCPP_2_MR(), "total_mr-3mm", 2, True, url+"v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
105
- #"M853" : (UnetCPP_3_MR(), False, 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip")
 
 
106
  }
107
  if __name__ == "__main__":
108
  for name, model in models.items():
 
94
  "M294" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
95
  "M295" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
96
  "M297" : (UnetCPP_2_CT(), "total-3mm", 2, False, url+"v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
97
+ "M298" : (UnetCPP_2_CT(), 2, False, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
98
+ "M730" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
99
+ "M731" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
100
+ "M732" : (UnetCPP_2_MR(), False, 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
101
+ "M733" : (UnetCPP_3_MR(), False, 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
102
  "M850" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
103
  "M851" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
104
  "M852" : (UnetCPP_2_MR(), "total_mr-3mm", 2, True, url+"v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
105
+ "M853" : (UnetCPP_3_MR(), False, 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip"),
106
+ "M299" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset299_body_1559subj.zip"),
107
+ "M597" : (UnetCPP_1_MR(), "body", 1, True, url+"v2.5.0-weights/Dataset597_mri_body_139subj.zip"),
108
  }
109
  if __name__ == "__main__":
110
  for name, model in models.items():