import torch
import requests
from tqdm import tqdm
import zipfile
import shutil
from pathlib import Path
import os
from functools import partial
from total.model import Unet_TS_CT
from total_mr.model import Unet_TS_MR
import json

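def convert_torchScript_full(model_name: str, model: torch.nn.Module, task: str, unet_type: int, mri: bool, url: str) -> None:
    """Download a TotalSegmentator nnU-Net checkpoint, rename its weights for `model` and save them.

    The parameter names of `model` are read, one per line, from
    `Destination_Unet_{unet_type}.txt` and paired in order with the checkpoint keys.
    Segmentation heads other than the final one, and duplicate keys under
    `all_modules` or `decoder.encoder`, are skipped.
    """
    state_dict = download(url, model_name, mri)
    # Only the last segmentation head is kept; the other heads are deep-supervision outputs.
    final_seg_layer = {1: "decoder.seg_layers.4", 2: "decoder.seg_layers.3", 3: "decoder.seg_layers.2"}[unet_type]
    remapped = {}
    with open(f"Destination_Unet_{unet_type}.txt") as f:
        source_keys = iter(state_dict.keys())
        for line in f:
            dest_key = line.rstrip("\n")
            key = next(source_keys)
            # Skip the segmentation heads that the destination model does not have.
            while "decoder.seg_layers" in key and final_seg_layer not in key:
                key = next(source_keys)
            # Skip duplicate references to weights that are already listed elsewhere.
            while "all_modules" in key or "decoder.encoder" in key:
                key = next(source_keys)
            print(f"{key} -> {dest_key}")
            remapped[dest_key] = state_dict[key]
    if not mri:
        # CT models also store their intensity normalization statistics.
        remapped["ClipAndNormalize.mean"] = state_dict["mean"]
        remapped["ClipAndNormalize.std"] = state_dict["std"]
        remapped["ClipAndNormalize.clip_min"] = state_dict["percentile_00_5"]
        remapped["ClipAndNormalize.clip_max"] = state_dict["percentile_99_5"]
    state_dict = {"Model": {model.name: remapped}}
    # Load the remapped weights into the target network before saving them.
    model.load(state_dict)
    dest_path = Path(f"./{task}")
    dest_path.mkdir(parents=True, exist_ok=True)
    torch.save(state_dict, str(dest_path / f"{model_name}.pt"))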
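
# A minimal, framework-free sketch of the key pairing performed in convert_torchScript_full,
# assuming plain lists of key names instead of a real checkpoint. `remap_keys` is a
# hypothetical helper for illustration only; it is not called anywhere in this script.
def remap_keys(source_keys: list[str], dest_keys: list[str], final_seg_layer: str) -> dict[str, str]:
    """Pair each destination name with the next usable source key, in order."""
    mapping: dict[str, str] = {}
    it = iter(source_keys)
    for dest in dest_keys:
        key = next(it)
        # Drop segmentation heads other than the final one.
        while "decoder.seg_layers" in key and final_seg_layer not in key:
            key = next(it)
        # Drop duplicate references (wrapper modules, the decoder's copy of the encoder).
        while "all_modules" in key or "decoder.encoder" in key:
            key = next(it)
        mapping[dest] = key
    return mapping

# Example with made-up key names:
# remap_keys(
#     ["encoder.w", "encoder.all_modules.0.w", "decoder.encoder.w",
#      "decoder.stages.w", "decoder.seg_layers.0.w", "decoder.seg_layers.1.w"],
#     ["enc.w", "dec.w", "head.w"],
#     "decoder.seg_layers.1",
# )
# -> {"enc.w": "encoder.w", "dec.w": "decoder.stages.w", "head.w": "decoder.seg_layers.1.w"}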

def download(url: str, model_name: str, mri: bool) -> dict[str, torch.Tensor]:
    """Download and unpack an nnU-Net results archive and return its final network weights.

    For CT models the channel-0 foreground intensity statistics from
    dataset_fingerprint.json are added under "mean", "std",
    "percentile_00_5" and "percentile_99_5".
    """
    archive_name = url.split("/")[-1]
    extract_dir = Path(archive_name.replace(".zip", ""))

    with open(archive_name, "wb") as f, requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Downloading {model_name}") as progress_bar:
            for chunk in r.iter_content(chunk_size=8192 * 16):
                progress_bar.update(len(chunk))
                f.write(chunk)
    with zipfile.ZipFile(archive_name, "r") as zip_f:
        zip_f.extractall(extract_dir)
    os.remove(archive_name)

    checkpoint_path = next(extract_dir.rglob("checkpoint_final.pth"), None)
    if checkpoint_path is None:
        raise FileNotFoundError(f"checkpoint_final.pth not found in {extract_dir}")
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)["network_weights"]
    if not mri:
        fingerprint_path = next(extract_dir.rglob("dataset_fingerprint.json"), None)
        if fingerprint_path is None:
            raise FileNotFoundError(f"dataset_fingerprint.json not found in {extract_dir}")
        with open(fingerprint_path, "r") as f:
            data = json.load(f)

        ch0 = data["foreground_intensity_properties_per_channel"]["0"]
        for stat in ("mean", "std", "percentile_00_5", "percentile_99_5"):
            state_dict[stat] = torch.tensor([ch0[stat]])
    shutil.rmtree(extract_dir)
    return state_dict
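
# A minimal sketch of how the four statistics collected above are typically applied,
# assuming nnU-Net-style CT normalization: clip to [percentile_00_5, percentile_99_5],
# then subtract the mean and divide by the standard deviation. `ct_normalize` is a
# hypothetical helper on plain floats; the actual ClipAndNormalize layer is defined in
# the network code and may differ in detail.
def ct_normalize(values: list[float], props: dict[str, float]) -> list[float]:
    """Clip each value to the 0.5/99.5 percentiles, then standardize it."""
    lo, hi = props["percentile_00_5"], props["percentile_99_5"]
    mean, std = props["mean"], max(props["std"], 1e-8)  # guard against division by zero
    return [(min(max(v, lo), hi) - mean) / std for v in values]

# Example with made-up fingerprint values:
# ct_normalize([-2000.0, 40.0, 3000.0],
#              {"mean": 40.0, "std": 10.0, "percentile_00_5": -1000.0, "percentile_99_5": 1000.0})
# -> [-104.0, 0.0, 96.0]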

url = "https://github.com/wasserth/TotalSegmentator/releases/download/"

UnetCPP_1_CT = partial(Unet_TS_CT, channels = [1,32,64,128,256,320,320])
UnetCPP_2_CT = partial(Unet_TS_CT, channels = [1,32,64,128,256,320])
UnetCPP_3_CT = partial(Unet_TS_CT, channels = [1,32,64,128,256])

UnetCPP_1_MR = partial(Unet_TS_MR, channels = [1,32,64,128,256,320,320])
UnetCPP_2_MR = partial(Unet_TS_MR, channels = [1,32,64,128,256,320])
UnetCPP_3_MR = partial(Unet_TS_MR, channels = [1,32,64,128,256])
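
# The segmentation head kept by convert_torchScript_full (seg_layers.4 / .3 / .2 for
# variants 1 / 2 / 3) equals len(channels) - 3 for the channel lists above, assuming the
# nnU-Net layout: channels[0] is the input channel count, each other entry is one encoder
# stage, and the decoder has one head per stage except the deepest, the last head being
# the full-resolution output. `final_seg_layer_index` is a hypothetical helper that makes
# this relation explicit; the script itself hard-codes the indices.
def final_seg_layer_index(channels: list[int]) -> int:
    n_stages = len(channels) - 1  # channels[0] is the input channel count
    n_heads = n_stages - 1        # one segmentation head per decoder stage
    return n_heads - 1            # zero-based index of the last head

# final_seg_layer_index([1, 32, 64, 128, 256, 320, 320]) -> 4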

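# Each entry: name -> (network, output task directory, U-Net variant 1/2/3, is_mri, checkpoint archive URL).
# The commented-out entries do not follow this layout and must be updated before being re-enabled.
models = {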
         "M291" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip"), 
         "M292" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip"),
         "M293" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip"),
         "M294" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
         "M295" : (UnetCPP_1_CT(), "total", 1, False, url+"v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
         "M297" : (UnetCPP_2_CT(), "total-3mm", 2, False, url+"v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
         #"M298" : (UnetCPP_2_CT(), 2, False, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
         #"M730" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
         #"M731" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
         #"M732" : (UnetCPP_2_MR(), False, 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
         #"M733" : (UnetCPP_3_MR(), False, 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
         "M850" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
         "M851" : (UnetCPP_1_MR(), "total_mr", 1, True, url+"v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
         "M852" : (UnetCPP_2_MR(), "total_mr-3mm", 2, True, url+"v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
         #"M853" : (UnetCPP_3_MR(), False, 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip")
         }
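
# A hypothetical, more self-describing alternative to the positional tuples above,
# assuming the same field order that convert_torchScript_full expects. Because a
# NamedTuple is still a tuple, `convert_torchScript_full(name, *spec)` would work unchanged.
from typing import Any, NamedTuple

class ModelSpec(NamedTuple):
    network: Any    # e.g. UnetCPP_1_CT()
    task: str       # output directory for the converted weights
    unet_type: int  # 1, 2 or 3; selects Destination_Unet_{unet_type}.txt
    mri: bool       # True skips the CT intensity statistics
    url: str        # nnU-Net results archive to download
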
if __name__ == "__main__":
    for name, spec in models.items():
        convert_torchScript_full(name, *spec)