Valentin Boussot committed on
Commit
7a52c73
·
1 Parent(s): 6fe049a

Fix clipping and normalization in CT-specific model preprocessing

Browse files
Build.py → build.py RENAMED
@@ -6,15 +6,20 @@ import shutil
6
  from pathlib import Path
7
  import os
8
  from functools import partial
9
- from Model import Unet_TS
 
 
10
 
11
- def convert_torchScript_full(model_name: str, model: torch.nn.Module, type: int, url: str) -> None:
12
- state_dict = download(url)
13
  tmp = {}
14
  with open("Destination_Unet_{}.txt".format(type)) as f2:
15
  it = iter(state_dict.keys())
16
  for l1 in f2:
 
 
17
  key = next(it)
 
18
  while "decoder.seg_layers" in key:
19
  if type == 1:
20
  if "decoder.seg_layers.4" in key :
@@ -30,17 +35,24 @@ def convert_torchScript_full(model_name: str, model: torch.nn.Module, type: int,
30
  while "all_modules" in key or "decoder.encoder" in key:
31
  key = next(it)
32
  tmp[l1.replace("\n", "")] = state_dict[key]
 
 
 
 
 
 
 
 
 
 
33
 
34
- model.load_state_dict(tmp)
35
- torch.save({"Model" : {"Unet_TS" : tmp}}, f"{model_name}.pt")
36
-
37
- def download(url: str) -> dict[str, torch.Tensor]:
38
  with open(url.split("/")[-1], 'wb') as f:
39
  with requests.get(url, stream=True) as r:
40
  r.raise_for_status()
41
 
42
  total_size = int(r.headers.get('content-length', 0))
43
- progress_bar = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")
44
  for chunk in r.iter_content(chunk_size=8192 * 16):
45
  progress_bar.update(len(chunk))
46
  f.write(chunk)
@@ -48,33 +60,50 @@ def download(url: str) -> dict[str, torch.Tensor]:
48
  with zipfile.ZipFile(url.split("/")[-1], 'r') as zip_f:
49
  zip_f.extractall(url.split("/")[-1].replace(".zip", ""))
50
  os.remove(url.split("/")[-1])
51
- state_dict = torch.load(next(Path(url.split("/")[-1].replace(".zip", "")).rglob("checkpoint_final.pth"), None), weights_only=False)["network_weights"]
52
- shutil.rmtree(url.split("/")[-1].replace(".zip", ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return state_dict
54
 
55
  url = "https://github.com/wasserth/TotalSegmentator/releases/download/"
56
 
57
- UnetCPP_1 = partial(Unet_TS, channels = [1,32,64,128,256,320,320])
58
- UnetCPP_2 = partial(Unet_TS, channels = [1,32,64,128,256,320])
59
- UnetCPP_3 = partial(Unet_TS, channels = [1,32,64,128,256])
60
 
61
- models = {
62
- "M291" : (UnetCPP_1(nb_class=25), 1, url+"v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip"),
63
- "M292" : (UnetCPP_1(nb_class=27), 1, url+"v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip"),
64
- "M293" : (UnetCPP_1(nb_class=19), 1, url+"v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip"),
65
- "M294" : (UnetCPP_1(nb_class=24), 1, url+"v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
66
- "M295" : (UnetCPP_1(nb_class=27), 1, url+"v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
67
- "M297" : (UnetCPP_2(nb_class=118), 2, url+"v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
68
- "M298" : (UnetCPP_2(nb_class=118), 2, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
69
- "M730" : (UnetCPP_1(nb_class=30, mri = True), 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
70
- "M731" : (UnetCPP_1(nb_class=28, mri = True), 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
71
- "M732" : (UnetCPP_2(nb_class=57), 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
72
- "M733" : (UnetCPP_3(nb_class=57), 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
73
- "M850" : (UnetCPP_1(nb_class=30, mri = True), 1, url+"v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
74
- "M851" : (UnetCPP_1(nb_class=22, mri = True), 1, url+"v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
75
- "M852" : (UnetCPP_2(nb_class=51), 2, url+"v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
76
- "M853" : (UnetCPP_3(nb_class=51), 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip")}
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if __name__ == "__main__":
79
  for name, model in models.items():
80
- convert_torchScript_full(name, model[0], model[1], model[2])
 
6
  from pathlib import Path
7
  import os
8
  from functools import partial
9
+ from total.model import Unet_TS_CT
10
+ from total_mr.model import Unet_TS_MR
11
+ import json
12
 
13
+ def convert_torchScript_full(model_name: str, model: torch.nn.Module, task: str, type: int, mri: bool, url: str) -> None:
14
+ state_dict = download(url, model_name, mri)
15
  tmp = {}
16
  with open("Destination_Unet_{}.txt".format(type)) as f2:
17
  it = iter(state_dict.keys())
18
  for l1 in f2:
19
+ print(l1)
20
+
21
  key = next(it)
22
+ print(key)
23
  while "decoder.seg_layers" in key:
24
  if type == 1:
25
  if "decoder.seg_layers.4" in key :
 
35
  while "all_modules" in key or "decoder.encoder" in key:
36
  key = next(it)
37
  tmp[l1.replace("\n", "")] = state_dict[key]
38
+ if not mri:
39
+ tmp["ClipAndNormalize.mean"] = state_dict["mean"]
40
+ tmp["ClipAndNormalize.std"] = state_dict["std"]
41
+ tmp["ClipAndNormalize.clip_min"] = state_dict["percentile_00_5"]
42
+ tmp["ClipAndNormalize.clip_max"] = state_dict["percentile_99_5"]
43
+ state_dict = {"Model" : {model.name : tmp}}
44
+ model.load(state_dict)
45
+ dest_path = Path(f"./{task}")
46
+ dest_path.mkdir(exist_ok=True)
47
+ torch.save(state_dict, str(dest_path/f"{model_name}.pt"))
48
 
49
+ def download(url: str, model_name: str, mri: bool) -> dict[str, torch.Tensor]:
 
 
 
50
  with open(url.split("/")[-1], 'wb') as f:
51
  with requests.get(url, stream=True) as r:
52
  r.raise_for_status()
53
 
54
  total_size = int(r.headers.get('content-length', 0))
55
+ progress_bar = tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {model_name}")
56
  for chunk in r.iter_content(chunk_size=8192 * 16):
57
  progress_bar.update(len(chunk))
58
  f.write(chunk)
 
60
  with zipfile.ZipFile(url.split("/")[-1], 'r') as zip_f:
61
  zip_f.extractall(url.split("/")[-1].replace(".zip", ""))
62
  os.remove(url.split("/")[-1])
63
+ zip_path = Path(url.split("/")[-1].replace(".zip", ""))
64
+
65
+ state_dict = torch.load(next(zip_path.rglob("checkpoint_final.pth"), None), map_location="cpu", weights_only=False)["network_weights"]
66
+ if not mri:
67
+ dataset_fingerprint_path = next(zip_path.rglob("dataset_fingerprint.json"), None)
68
+ with open(dataset_fingerprint_path, "r") as f:
69
+ data = json.load(f)
70
+
71
+ ch0 = data["foreground_intensity_properties_per_channel"]["0"]
72
+
73
+ state_dict["mean"] = torch.tensor([ch0["mean"]])
74
+ state_dict["std"] = torch.tensor([ch0["std"]])
75
+ state_dict["percentile_00_5"] = torch.tensor([ch0["percentile_00_5"]])
76
+ state_dict["percentile_99_5"] = torch.tensor([ch0["percentile_99_5"]])
77
+ shutil.rmtree(zip_path)
78
  return state_dict
79
 
80
  url = "https://github.com/wasserth/TotalSegmentator/releases/download/"
81
 
82
+ UnetCPP_1_CT = partial(Unet_TS_CT, channels = [1,32,64,128,256,320,320])
83
+ UnetCPP_2_CT = partial(Unet_TS_CT, channels = [1,32,64,128,256,320])
84
+ UnetCPP_3_CT = partial(Unet_TS_CT, channels = [1,32,64,128,256])
85
 
86
+ UnetCPP_1_MR = partial(Unet_TS_MR, channels = [1,32,64,128,256,320,320])
87
+ UnetCPP_2_MR = partial(Unet_TS_MR, channels = [1,32,64,128,256,320])
88
+ UnetCPP_3_MR = partial(Unet_TS_MR, channels = [1,32,64,128,256])
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ models = {
91
+ #"M291" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip"),
92
+ #"M292" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip"),
93
+ #"M293" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip"),
94
+ #"M294" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
95
+ #"M295" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
96
+ #"M297" : (UnetCPP_2_CT(), "total-3mm", 2, False, url+"v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
97
+ #"M298" : (UnetCPP_2_CT(), 2, False, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
98
+ #"M730" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
99
+ #"M731" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
100
+ #"M732" : (UnetCPP_2_MR(), False, 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
101
+ #"M733" : (UnetCPP_3_MR(), False, 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
102
+ #"M850" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
103
+ #"M851" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
104
+ "M852" : (UnetCPP_2_MR(), "total_mr-3mm", 2, True, url+"v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
105
+ #"M853" : (UnetCPP_3_MR(), False, 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip")
106
+ }
107
  if __name__ == "__main__":
108
  for name, model in models.items():
109
+ convert_torchScript_full(name, *model)
total-3mm/M297.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6490a06ac9d242757af99fc674ba8f74ecb5c62009ecda6c62cda217b9cbbcb7
3
- size 66225317
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d50dab1ce25e6d55f9e7b2f6a5a6282293e6c9add63278558149e8a1146f5f5b
3
+ size 66225357
total-3mm/Prediction.yml CHANGED
@@ -1,7 +1,7 @@
1
  Predictor:
2
  Model:
3
- classpath: Model:Unet_TS
4
- Unet_TS:
5
  outputs_criterions: None
6
  channels:
7
  - 1
@@ -10,7 +10,6 @@ Predictor:
10
  - 128
11
  - 256
12
  - 320
13
- mri: false
14
  Dataset:
15
  groups_src:
16
  Volume_0:
@@ -22,24 +21,18 @@ Predictor:
22
  inverse: false
23
  Canonical:
24
  inverse: true
25
- Clip:
26
- min_value: -1024
27
- max_value: 276
28
- save_clip_min: false
29
- save_clip_max: false
30
- mask: None
31
- Standardize:
32
- lazy: false
33
- mean: -370.00039267657144
34
- std: 436.5998675471528
35
- mask: None
36
- inverse: false
37
  ResampleToResolution:
38
  spacing:
39
  - 3
40
  - 3
41
  - 3
42
  inverse: true
 
 
 
 
 
 
43
  Padding:
44
  padding:
45
  - 32
 
1
  Predictor:
2
  Model:
3
+ classpath: model:Unet_TS_CT
4
+ Unet_TS_CT:
5
  outputs_criterions: None
6
  channels:
7
  - 1
 
10
  - 128
11
  - 256
12
  - 320
 
13
  Dataset:
14
  groups_src:
15
  Volume_0:
 
21
  inverse: false
22
  Canonical:
23
  inverse: true
 
 
 
 
 
 
 
 
 
 
 
 
24
  ResampleToResolution:
25
  spacing:
26
  - 3
27
  - 3
28
  - 3
29
  inverse: true
30
+ Standardize:
31
+ lazy: false
32
+ mean: None
33
+ std: None
34
+ mask: None
35
+ inverse: false
36
  Padding:
37
  padding:
38
  - 32
total-3mm/app.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "display_name": "TotalSegmentator 3mm",
3
  "short_description": "<b>Description:</b><br>Lightweight KonfAI adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a> trained at <b>3 mm resolution</b>, reducing GPU/RAM requirements while segmenting <b>118 anatomical structures</b> in whole-body CT.<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
4
  "description": "<b>Description:</b><br>KonfAI-optimized version of the original nnU-Net-based TotalSegmentator 3 mm model.<br><br><b>Capabilities:</b><br>• Whole-body CT segmentation of <b>118 structures</b> (organs, bones, muscles, vessels)<br>• Reduced computational footprint for lower memory and faster throughput<br>• <b>3 mm isotropic</b> inference for easier deployment on large datasets<br><br><b>Training data:</b><br>Trained on <b>1204 clinically-derived CT scans</b> with strong diversity in contrast phases, scanner types and pathologies, with expert-reviewed manual annotations<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
5
  "tta": 0,
 
1
  {
2
+ "display_name": "Segmentation: TotalSegmentator 3mm",
3
  "short_description": "<b>Description:</b><br>Lightweight KonfAI adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a> trained at <b>3 mm resolution</b>, reducing GPU/RAM requirements while segmenting <b>118 anatomical structures</b> in whole-body CT.<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
4
  "description": "<b>Description:</b><br>KonfAI-optimized version of the original nnU-Net-based TotalSegmentator 3 mm model.<br><br><b>Capabilities:</b><br>• Whole-body CT segmentation of <b>118 structures</b> (organs, bones, muscles, vessels)<br>• Reduced computational footprint for lower memory and faster throughput<br>• <b>3 mm isotropic</b> inference for easier deployment on large datasets<br><br><b>Training data:</b><br>Trained on <b>1204 clinically-derived CT scans</b> with strong diversity in contrast phases, scanner types and pathologies, with expert-reviewed manual annotations<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
5
  "tta": 0,
total/Model.py → total-3mm/model.py RENAMED
@@ -1,6 +1,5 @@
1
  import torch
2
  from konfai.network import network, blocks
3
- from konfai.predictor import Reduction
4
 
5
  class ConvBlock(network.ModuleArgsDict):
6
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
@@ -20,19 +19,32 @@ class UNetHead(network.ModuleArgsDict):
20
 
21
  class UNetBlock(network.ModuleArgsDict):
22
 
23
- def __init__(self, channels, mri: bool, i : int = 0) -> None:
24
  super().__init__()
25
- self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
26
 
27
  if len(channels) > 2:
28
- self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
29
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
30
 
31
  if i > 0:
32
- self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
33
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
34
 
35
- class Unet_TS(network.Network):
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def __init__(self,
38
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
@@ -40,8 +52,7 @@ class Unet_TS(network.Network):
40
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
41
  },
42
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
43
- channels = [1, 32, 64, 128, 320, 320],
44
- mri: bool = False) -> None:
45
  super().__init__(
46
  in_channels=channels[0],
47
  optimizer=optimizer,
@@ -50,8 +61,9 @@ class Unet_TS(network.Network):
50
  patch=None,
51
  dim=3,
52
  )
53
- self.add_module("UNetBlock", UNetBlock(channels, mri))
54
- self.add_module("Head", UNetHead(channels[1], 42))
 
55
 
56
  def load(
57
  self,
@@ -59,6 +71,6 @@ class Unet_TS(network.Network):
59
  init: bool = True,
60
  ema: bool = False,
61
  ):
62
- nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
63
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
64
- super().load(state_dict, init, ema)
 
1
  import torch
2
  from konfai.network import network, blocks
 
3
 
4
  class ConvBlock(network.ModuleArgsDict):
5
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
 
19
 
20
  class UNetBlock(network.ModuleArgsDict):
21
 
22
+ def __init__(self, channels, i : int = 0) -> None:
23
  super().__init__()
24
+ self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= 2 if i>0 else 1))
25
 
26
  if len(channels) > 2:
27
+ self.add_module("UNetBlock", UNetBlock(channels[1:], i+1))
28
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
29
 
30
  if i > 0:
31
+ self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = 2, stride = 2, padding = 0))
32
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
33
 
34
+ class ClipAndNormalize(torch.nn.Module):
35
+
36
+ def __init__(self) -> None:
37
+ super().__init__()
38
+ self.register_buffer("clip_min", torch.empty(1))
39
+ self.register_buffer("clip_max", torch.empty(1))
40
+ self.register_buffer("mean", torch.empty(1))
41
+ self.register_buffer("std", torch.empty(1))
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ x = torch.clamp(x, self.clip_min, self.clip_max)
45
+ return (x - self.mean) / (self.std)
46
+
47
+ class Unet_TS_CT(network.Network):
48
 
49
  def __init__(self,
50
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
 
52
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
53
  },
54
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
55
+ channels: list[int] = [1, 32, 64, 128, 320]) -> None:
 
56
  super().__init__(
57
  in_channels=channels[0],
58
  optimizer=optimizer,
 
61
  patch=None,
62
  dim=3,
63
  )
64
+ self.add_module("ClipAndNormalize", ClipAndNormalize())
65
+ self.add_module("UNetBlock", UNetBlock(channels))
66
+ self.add_module("Head", UNetHead(channels[1], 118))
67
 
68
  def load(
69
  self,
 
71
  init: bool = True,
72
  ema: bool = False,
73
  ):
74
+ nb_class, in_channels = state_dict["Model"]["Unet_TS_CT"]["Head.Conv.weight"].shape[:2]
75
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
76
+ super().load(state_dict, init, ema)
total/M291.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:78065b11af6e339feaa49c72c3aa45d78ae272be1432b27bc17285428a851e6a
3
- size 124807929
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a28e7fe8660329f225f93dbb61feefa0037a5734de87f2a35ab18d2ddfb7601c
3
+ size 124807717
total/M292.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1e0a8d8c17572b392bfdbc480edd169ac46fd8962fdf30b81de2406a6f32f275
3
- size 124808185
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:374ffdb916c267ab5f96ca9f2b754ff17eff9b48762de686e9524e0ade5909f3
3
+ size 124807973
total/M293.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9638deb528c1dafc8edfe5287404eefc764770c913810628cd62b23b5950a4c0
3
- size 124807161
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3104f4cc3c5796468a9d0d68fee602786eeb414a2a28ff2a8bd7f2210604daed
3
+ size 124806949
total/M294.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:11879030dad493ff88d49cbcbbbbee1f5671c79df743ae931b1bd7bbf7302b5e
3
- size 124807801
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8293c863503c8fa858f2eee7f599717b0cc4c16a50b0500107af1e8d905fe669
3
+ size 124807589
total/M295.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4e586a1082dcdb3a054e2ccd670d2506c7436e1f0f6f616cdc771bfe5c41d948
3
- size 124808185
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01bc78eee961a71a675c7335c88622b1524151982b546113fe1e9a4e3a2f8032
3
+ size 124807973
total/Prediction.yml CHANGED
@@ -1,7 +1,7 @@
1
  Predictor:
2
  Model:
3
- classpath: Model:Unet_TS
4
- Unet_TS:
5
  outputs_criterions: None
6
  channels:
7
  - 1
@@ -11,7 +11,6 @@ Predictor:
11
  - 256
12
  - 320
13
  - 320
14
- mri: false
15
  Dataset:
16
  groups_src:
17
  Volume_0:
@@ -23,18 +22,6 @@ Predictor:
23
  inverse: false
24
  Canonical:
25
  inverse: true
26
- Clip:
27
- min_value: -1024
28
- max_value: 276
29
- save_clip_min: false
30
- save_clip_max: false
31
- mask: None
32
- Standardize:
33
- lazy: false
34
- mean: -370.00039267657144
35
- std: 436.5998675471528
36
- mask: None
37
- inverse: false
38
  ResampleToResolution:
39
  spacing:
40
  - 1.5
 
1
  Predictor:
2
  Model:
3
+ classpath: model:Unet_TS_CT
4
+ Unet_TS_CT:
5
  outputs_criterions: None
6
  channels:
7
  - 1
 
11
  - 256
12
  - 320
13
  - 320
 
14
  Dataset:
15
  groups_src:
16
  Volume_0:
 
22
  inverse: false
23
  Canonical:
24
  inverse: true
 
 
 
 
 
 
 
 
 
 
 
 
25
  ResampleToResolution:
26
  spacing:
27
  - 1.5
total/app.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "display_name": "Total Segmentator",
3
  "short_description": "<b>Description:</b><br>KonfAI-accelerated adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a>, delivering fast whole-body CT segmentation of <b>118 anatomical structures</b> (1.5 mm resolution) with reduced inference cost compared to the original nnU-Net implementation.<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
4
  "description": "<b>Description:</b><br>This model is an optimized adaptation of the original <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a> for the <b>KonfAI</b> deep learning framework.<br><br><b>Capabilities:</b><br>• Segmentation of <b>118 anatomical classes</b> covering organs, bones, muscles, and vessels<br>• Enhanced runtime and memory efficiency vs. the original nnU-Net implementation<br>• High-resolution input: <b>1.5 mm isotropic</b><br><br><b>Training data:</b><br>Trained on a diverse dataset of <b>1204 whole-body CT examinations</b> including different scanners, acquisition settings, contrast phases, and major pathologies (27 organs, 59 bones, 10 muscles, 8 vessels), with manual expert-reviewed annotations<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
5
  "tta": 0,
 
1
  {
2
+ "display_name": "Segmentation: Total Segmentator",
3
  "short_description": "<b>Description:</b><br>KonfAI-accelerated adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a>, delivering fast whole-body CT segmentation of <b>118 anatomical structures</b> (1.5 mm resolution) with reduced inference cost compared to the original nnU-Net implementation.<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
4
  "description": "<b>Description:</b><br>This model is an optimized adaptation of the original <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator</a> for the <b>KonfAI</b> deep learning framework.<br><br><b>Capabilities:</b><br>• Segmentation of <b>118 anatomical classes</b> covering organs, bones, muscles, and vessels<br>• Enhanced runtime and memory efficiency vs. the original nnU-Net implementation<br>• High-resolution input: <b>1.5 mm isotropic</b><br><br><b>Training data:</b><br>Trained on a diverse dataset of <b>1204 whole-body CT examinations</b> including different scanners, acquisition settings, contrast phases, and major pathologies (27 organs, 59 bones, 10 muscles, 8 vessels), with manual expert-reviewed annotations<br><br><b>How to cite:</b><br><cite>J. Wasserthal et al., <i>TotalSegmentator: Robust Segmentation of 104 Anatomical Structures in CT Images</i>, Radiology: AI, 2023.</cite>",
5
  "tta": 0,
total_mr-3mm/Model.py → total/model.py RENAMED
@@ -1,6 +1,5 @@
1
  import torch
2
  from konfai.network import network, blocks
3
- from konfai.predictor import Reduction
4
 
5
  class ConvBlock(network.ModuleArgsDict):
6
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
@@ -20,19 +19,32 @@ class UNetHead(network.ModuleArgsDict):
20
 
21
  class UNetBlock(network.ModuleArgsDict):
22
 
23
- def __init__(self, channels, mri: bool, i : int = 0) -> None:
24
  super().__init__()
25
- self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
26
 
27
  if len(channels) > 2:
28
- self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
29
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
30
 
31
  if i > 0:
32
- self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
33
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
34
 
35
- class Unet_TS(network.Network):
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def __init__(self,
38
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
@@ -40,8 +52,7 @@ class Unet_TS(network.Network):
40
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
41
  },
42
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
43
- channels = [1, 32, 64, 128, 320, 320],
44
- mri: bool = False) -> None:
45
  super().__init__(
46
  in_channels=channels[0],
47
  optimizer=optimizer,
@@ -50,8 +61,9 @@ class Unet_TS(network.Network):
50
  patch=None,
51
  dim=3,
52
  )
53
- self.add_module("UNetBlock", UNetBlock(channels, mri))
54
- self.add_module("Head", UNetHead(channels[1], 42))
 
55
 
56
  def load(
57
  self,
@@ -59,6 +71,6 @@ class Unet_TS(network.Network):
59
  init: bool = True,
60
  ema: bool = False,
61
  ):
62
- nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
63
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
64
- super().load(state_dict, init, ema)
 
1
  import torch
2
  from konfai.network import network, blocks
 
3
 
4
  class ConvBlock(network.ModuleArgsDict):
5
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
 
19
 
20
  class UNetBlock(network.ModuleArgsDict):
21
 
22
+ def __init__(self, channels, i : int = 0) -> None:
23
  super().__init__()
24
+ self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= 2 if i>0 else 1))
25
 
26
  if len(channels) > 2:
27
+ self.add_module("UNetBlock", UNetBlock(channels[1:], i+1))
28
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
29
 
30
  if i > 0:
31
+ self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = 2, stride = 2, padding = 0))
32
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
33
 
34
+ class ClipAndNormalize(torch.nn.Module):
35
+
36
+ def __init__(self) -> None:
37
+ super().__init__()
38
+ self.register_buffer("clip_min", torch.empty(1))
39
+ self.register_buffer("clip_max", torch.empty(1))
40
+ self.register_buffer("mean", torch.empty(1))
41
+ self.register_buffer("std", torch.empty(1))
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ x = torch.clamp(x, self.clip_min, self.clip_max)
45
+ return (x - self.mean) / (self.std)
46
+
47
+ class Unet_TS_CT(network.Network):
48
 
49
  def __init__(self,
50
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
 
52
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
53
  },
54
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
55
+ channels: list[int] = [1, 32, 64, 128, 320, 320]) -> None:
 
56
  super().__init__(
57
  in_channels=channels[0],
58
  optimizer=optimizer,
 
61
  patch=None,
62
  dim=3,
63
  )
64
+ self.add_module("ClipAndNormalize", ClipAndNormalize())
65
+ self.add_module("UNetBlock", UNetBlock(channels))
66
+ self.add_module("Head", UNetHead(channels[1], 118))
67
 
68
  def load(
69
  self,
 
71
  init: bool = True,
72
  ema: bool = False,
73
  ):
74
+ nb_class, in_channels = state_dict["Model"]["Unet_TS_CT"]["Head.Conv.weight"].shape[:2]
75
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
76
+ super().load(state_dict, init, ema)
total_mr-3mm/M852.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:22f810c3c77079ca3a196f577d9ab21733aa1e94cb006896c9d35d7d7f726157
3
- size 66216485
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d74547add966128f251df5d76f61f151800990b9c4ee55aaa60b4646a9a3ea8
3
+ size 66215333
total_mr-3mm/Prediction.yml CHANGED
@@ -22,18 +22,18 @@ Predictor:
22
  inverse: false
23
  Canonical:
24
  inverse: true
25
- Standardize:
26
- lazy: false
27
- mean: None
28
- std: None
29
- mask: None
30
- inverse: false
31
  ResampleToResolution:
32
  spacing:
33
  - 3
34
  - 3
35
  - 3
36
  inverse: true
 
 
 
 
 
 
37
  Padding:
38
  padding:
39
  - 32
 
22
  inverse: false
23
  Canonical:
24
  inverse: true
 
 
 
 
 
 
25
  ResampleToResolution:
26
  spacing:
27
  - 3
28
  - 3
29
  - 3
30
  inverse: true
31
+ Standardize:
32
+ lazy: false
33
+ mean: None
34
+ std: None
35
+ mask: None
36
+ inverse: false
37
  Padding:
38
  padding:
39
  - 32
total_mr-3mm/app.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "display_name": "TotalSegmentator MRI 3mm",
3
  "short_description": "<b>Description:</b><br>Lightweight KonfAI adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator MRI</a>, enabling fast multimodal segmentation of <b>50 key anatomical structures</b> from <b>MRI and CT</b> scans at <b>3 mm</b> resolution, greatly reducing memory usage and inference time compared to the original nnU-Net workflow.<br><br><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
4
  "description": "<b>Description:</b><br>This model integrates the reduced-resolution MRI-3mm configuration of TotalSegmentator into the <b>KonfAI</b> accelerated inference pipeline for efficient MRI/CT deployment.<br><br><b>Capabilities:</b><br>• Segmentation of <b>50 essential anatomical structures</b> (major organs, key bones, large vessels)<br>• <b>3 mm</b> isotropic input for high-throughput processing and lower GPU requirements<br>• Robust to acquisition variability including scanner type, contrast, and sequence parameters<br><br><b>Training data:</b><br>Trained on a diverse cohort of <b>1143 clinical scans</b> including <b>616 MRI</b> (multi-site, multi-scanner, multi-sequence) and <b>527 CT</b>, with expert-validated reference masks<br><br>><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
5
  "tta": 0,
 
1
  {
2
+ "display_name": "Segmentation: TotalSegmentator MRI 3mm",
3
  "short_description": "<b>Description:</b><br>Lightweight KonfAI adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator MRI</a>, enabling fast multimodal segmentation of <b>50 key anatomical structures</b> from <b>MRI and CT</b> scans at <b>3 mm</b> resolution, greatly reducing memory usage and inference time compared to the original nnU-Net workflow.<br><br><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
4
  "description": "<b>Description:</b><br>This model integrates the reduced-resolution MRI-3mm configuration of TotalSegmentator into the <b>KonfAI</b> accelerated inference pipeline for efficient MRI/CT deployment.<br><br><b>Capabilities:</b><br>• Segmentation of <b>50 essential anatomical structures</b> (major organs, key bones, large vessels)<br>• <b>3 mm</b> isotropic input for high-throughput processing and lower GPU requirements<br>• Robust to acquisition variability including scanner type, contrast, and sequence parameters<br><br><b>Training data:</b><br>Trained on a diverse cohort of <b>1143 clinical scans</b> including <b>616 MRI</b> (multi-site, multi-scanner, multi-sequence) and <b>527 CT</b>, with expert-validated reference masks<br><br>><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
5
  "tta": 0,
total_mr/Model.py → total_mr-3mm/model.py RENAMED
@@ -1,6 +1,5 @@
1
  import torch
2
  from konfai.network import network, blocks
3
- from konfai.predictor import Reduction
4
 
5
  class ConvBlock(network.ModuleArgsDict):
6
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
@@ -20,19 +19,19 @@ class UNetHead(network.ModuleArgsDict):
20
 
21
  class UNetBlock(network.ModuleArgsDict):
22
 
23
- def __init__(self, channels, mri: bool, i : int = 0) -> None:
24
  super().__init__()
25
- self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
26
 
27
  if len(channels) > 2:
28
- self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
29
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
30
 
31
  if i > 0:
32
- self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
33
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
34
 
35
- class Unet_TS(network.Network):
36
 
37
  def __init__(self,
38
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
@@ -40,8 +39,7 @@ class Unet_TS(network.Network):
40
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
41
  },
42
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
43
- channels = [1, 32, 64, 128, 320, 320],
44
- mri: bool = False) -> None:
45
  super().__init__(
46
  in_channels=channels[0],
47
  optimizer=optimizer,
@@ -50,7 +48,7 @@ class Unet_TS(network.Network):
50
  patch=None,
51
  dim=3,
52
  )
53
- self.add_module("UNetBlock", UNetBlock(channels, mri))
54
  self.add_module("Head", UNetHead(channels[1], 42))
55
 
56
  def load(
@@ -59,6 +57,6 @@ class Unet_TS(network.Network):
59
  init: bool = True,
60
  ema: bool = False,
61
  ):
62
- nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
63
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
64
- super().load(state_dict, init, ema)
 
1
  import torch
2
  from konfai.network import network, blocks
 
3
 
4
  class ConvBlock(network.ModuleArgsDict):
5
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
 
19
 
20
  class UNetBlock(network.ModuleArgsDict):
21
 
22
+ def __init__(self, channels, i : int = 0) -> None:
23
  super().__init__()
24
+ self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if i > 4 else 2) if i>0 else 1))
25
 
26
  if len(channels) > 2:
27
+ self.add_module("UNetBlock", UNetBlock(channels[1:], i+1))
28
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
29
 
30
  if i > 0:
31
+ self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if i > 4 else 2, stride = (1,2,2) if i > 4 else 2, padding = 0))
32
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
33
 
34
+ class Unet_TS_MR(network.Network):
35
 
36
  def __init__(self,
37
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
 
39
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
40
  },
41
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
42
+ channels: list[int] = [1, 32, 64, 128, 320]) -> None:
 
43
  super().__init__(
44
  in_channels=channels[0],
45
  optimizer=optimizer,
 
48
  patch=None,
49
  dim=3,
50
  )
51
+ self.add_module("UNetBlock", UNetBlock(channels))
52
  self.add_module("Head", UNetHead(channels[1], 42))
53
 
54
  def load(
 
57
  init: bool = True,
58
  ema: bool = False,
59
  ):
60
+ nb_class, in_channels = state_dict["Model"]["Unet_TS_MR"]["Head.Conv.weight"].shape[:2]
61
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
62
+ super().load(state_dict, init, ema)
total_mr/M850.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dbcf0d153ff695b696748e6192a912b305f0b8deba3d04fec1d685ea39df4c37
3
- size 123170169
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc61844cc7053882deae8ca0b2deb89d369528c44f73c935687068968ed39f87
3
+ size 123168825
total_mr/M851.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:570f5d944fd07c26c89dd41cd105dbe470286bc91d01d92e3ba63d9904f03903
3
- size 123169145
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ced0be1dfb7e48d6e21261ac3969d5d421d6a07b5228af6b2df55cfb6972b73
3
+ size 123167801
total_mr/Prediction.yml CHANGED
@@ -1,7 +1,7 @@
1
  Predictor:
2
  Model:
3
- classpath: Model:Unet_TS
4
- Unet_TS:
5
  outputs_criterions: None
6
  channels:
7
  - 1
@@ -11,7 +11,6 @@ Predictor:
11
  - 256
12
  - 320
13
  - 320
14
- mri: true
15
  Dataset:
16
  groups_src:
17
  Volume_0:
 
1
  Predictor:
2
  Model:
3
+ classpath: model:Unet_TS_MR
4
+ Unet_TS_MR:
5
  outputs_criterions: None
6
  channels:
7
  - 1
 
11
  - 256
12
  - 320
13
  - 320
 
14
  Dataset:
15
  groups_src:
16
  Volume_0:
total_mr/__pycache__/model.cpython-313.pyc ADDED
Binary file (6.25 kB). View file
 
total_mr/app.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "display_name": "TotalSegmentator MRI",
3
  "short_description": "<b>Description:</b><br>KonfAI-accelerated adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator MRI</a>, delivering fast multimodal segmentation of <b>80 major anatomical structures</b> from <b>MRI and CT</b> scans, with significantly reduced inference overhead vs. the original nnU-Net workflow.<br><br><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
4
  "description": "<b>Description:</b><br>This model integrates TotalSegmentator MRI into the <b>KonfAI</b> inference framework to accelerate deployment in MRI/CT multimodal workflows.<br><br><b>Capabilities:</b><br>• Automatic segmentation of <b>80 major anatomical structures</b> (organs, vessels, skeleton, digestive system)<br>• Robust to <b>sequence variations</b> across scanners, contrasts, acquisition planes, and sites<br>• High-resolution input: <b>1.5 mm isotropic</b><br><br><b>Training data:</b><br>Trained on a highly diverse clinical dataset of <b>1143 scans</b> including <b>616 MRI</b> (30 scanners, 4 sites, many contrast types) and <b>527 CT</b> scans, with expert-validated manual segmentations<br><br><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
5
  "tta": 0,
 
1
  {
2
+ "display_name": "Segmentation: TotalSegmentator MRI",
3
  "short_description": "<b>Description:</b><br>KonfAI-accelerated adaptation of <a href='https://github.com/wasserth/TotalSegmentator'>TotalSegmentator MRI</a>, delivering fast multimodal segmentation of <b>80 major anatomical structures</b> from <b>MRI and CT</b> scans, with significantly reduced inference overhead vs. the original nnU-Net workflow.<br><br><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
4
  "description": "<b>Description:</b><br>This model integrates TotalSegmentator MRI into the <b>KonfAI</b> inference framework to accelerate deployment in MRI/CT multimodal workflows.<br><br><b>Capabilities:</b><br>• Automatic segmentation of <b>80 major anatomical structures</b> (organs, vessels, skeleton, digestive system)<br>• Robust to <b>sequence variations</b> across scanners, contrasts, acquisition planes, and sites<br>• High-resolution input: <b>1.5 mm isotropic</b><br><br><b>Training data:</b><br>Trained on a highly diverse clinical dataset of <b>1143 scans</b> including <b>616 MRI</b> (30 scanners, 4 sites, many contrast types) and <b>527 CT</b> scans, with expert-validated manual segmentations<br><br><b>How to cite:</b><br><cite>T. Akinci D’Antonoli et al., <i>TotalSegmentator MRI: Robust Sequence-Independent Segmentation of Multiple Anatomic Structures in MRI</i>, Radiology, 2025.</cite>",
5
  "tta": 0,
total-3mm/Model.py → total_mr/model.py RENAMED
@@ -1,6 +1,5 @@
1
  import torch
2
  from konfai.network import network, blocks
3
- from konfai.predictor import Reduction
4
 
5
  class ConvBlock(network.ModuleArgsDict):
6
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
@@ -20,19 +19,19 @@ class UNetHead(network.ModuleArgsDict):
20
 
21
  class UNetBlock(network.ModuleArgsDict):
22
 
23
- def __init__(self, channels, mri: bool, i : int = 0) -> None:
24
  super().__init__()
25
- self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
26
 
27
  if len(channels) > 2:
28
- self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
29
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
30
 
31
  if i > 0:
32
- self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
33
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
34
 
35
- class Unet_TS(network.Network):
36
 
37
  def __init__(self,
38
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
@@ -40,8 +39,7 @@ class Unet_TS(network.Network):
40
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
41
  },
42
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
43
- channels = [1, 32, 64, 128, 320, 320],
44
- mri: bool = False) -> None:
45
  super().__init__(
46
  in_channels=channels[0],
47
  optimizer=optimizer,
@@ -50,7 +48,7 @@ class Unet_TS(network.Network):
50
  patch=None,
51
  dim=3,
52
  )
53
- self.add_module("UNetBlock", UNetBlock(channels, mri))
54
  self.add_module("Head", UNetHead(channels[1], 42))
55
 
56
  def load(
@@ -59,6 +57,6 @@ class Unet_TS(network.Network):
59
  init: bool = True,
60
  ema: bool = False,
61
  ):
62
- nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
63
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
64
- super().load(state_dict, init, ema)
 
1
  import torch
2
  from konfai.network import network, blocks
 
3
 
4
  class ConvBlock(network.ModuleArgsDict):
5
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
 
19
 
20
  class UNetBlock(network.ModuleArgsDict):
21
 
22
+ def __init__(self, channels, i : int = 0) -> None:
23
  super().__init__()
24
+ self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if i > 4 else 2) if i>0 else 1))
25
 
26
  if len(channels) > 2:
27
+ self.add_module("UNetBlock", UNetBlock(channels[1:], i+1))
28
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
29
 
30
  if i > 0:
31
+ self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if i > 4 else 2, stride = (1,2,2) if i > 4 else 2, padding = 0))
32
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
33
 
34
+ class Unet_TS_MR(network.Network):
35
 
36
  def __init__(self,
37
  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
 
39
  "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
40
  },
41
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
42
+ channels: list[int] = [1, 32, 64, 128, 320, 320]) -> None:
 
43
  super().__init__(
44
  in_channels=channels[0],
45
  optimizer=optimizer,
 
48
  patch=None,
49
  dim=3,
50
  )
51
+ self.add_module("UNetBlock", UNetBlock(channels))
52
  self.add_module("Head", UNetHead(channels[1], 42))
53
 
54
  def load(
 
57
  init: bool = True,
58
  ema: bool = False,
59
  ):
60
+ nb_class, in_channels = state_dict["Model"]["Unet_TS_MR"]["Head.Conv.weight"].shape[:2]
61
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
62
+ super().load(state_dict, init, ema)