# Author: VBoussot
# Commit: 28814c0 "Update build.py" (verified)
import torch
import requests
from tqdm import tqdm
import zipfile
import shutil
from pathlib import Path
import os
from functools import partial
from total.model import Unet_TS_CT
from total_mr.model import Unet_TS_MR
import json
def convert_torchScript_full(model_name: str, model: torch.nn.Module, task: str, type: int, mri: bool, url: str) -> None:
    """Download one TotalSegmentator checkpoint, rename its weights and save them.

    Source weights, in checkpoint order, are paired one-to-one and in order
    with the destination parameter names listed one per line in
    ``Destination_Unet_{type}.txt``. Source keys with no destination
    counterpart are skipped:

    * ``decoder.seg_layers.*`` entries, except the single head kept for this
      network *type* (index 4, 3 or 2 for type 1, 2 or 3);
    * keys containing ``all_modules`` or ``decoder.encoder``.

    For CT models (``mri`` False), the intensity statistics added by
    ``download`` are stored as ``ClipAndNormalize.*`` parameters. The result
    is wrapped as ``{"Model": {model.name: weights}}`` and passed to
    ``model.load``; presumably this acts as a consistency check (confirm).
    It is then saved to ``./{task}/{model_name}.pt``.

    Args:
        model_name: Output file stem; also shown in the download progress bar.
        model: Destination network. It must expose ``name`` and ``load``
            (project API, not plain ``torch.nn.Module``).
        task: Output sub-directory, created if missing.
        type: Network variant. Selects the key-list file and the kept
            segmentation head. (Shadows the builtin; kept for compatibility.)
        mri: True for MR models, which carry no intensity statistics.
        url: URL of the weight archive.

    Raises:
        RuntimeError: If the checkpoint runs out of weights before every
            destination name has been matched.
    """
    state_dict = download(url, model_name, mri)

    # Only one segmentation head survives per variant. Presumably it is the
    # full-resolution one and the others are deep-supervision outputs (confirm).
    # For an unknown type every seg_layers entry is skipped, as before.
    kept_seg_layer = {
        1: "decoder.seg_layers.4",
        2: "decoder.seg_layers.3",
        3: "decoder.seg_layers.2",
    }.get(type)

    source_keys = iter(state_dict.keys())

    def next_source_key(dest_key: str) -> str:
        # A bare StopIteration would not say which mapping went wrong.
        try:
            return next(source_keys)
        except StopIteration:
            raise RuntimeError(
                f"{model_name}: checkpoint has no weight left for destination key '{dest_key}'"
            ) from None

    tmp = {}
    with open(f"Destination_Unet_{type}.txt", encoding="utf-8") as f2:
        for line in f2:
            # strip() also removes the '\r' of CRLF files. Skipping blank lines
            # keeps them from consuming a source key under the empty name.
            dest_key = line.strip()
            if not dest_key:
                continue
            print(dest_key)
            key = next_source_key(dest_key)
            print(key)
            while "decoder.seg_layers" in key:
                if kept_seg_layer is not None and kept_seg_layer in key:
                    break
                key = next_source_key(dest_key)
            while "all_modules" in key or "decoder.encoder" in key:
                key = next_source_key(dest_key)
            tmp[dest_key] = state_dict[key]

    if not mri:
        # Export the CT intensity statistics with the weights. Presumably they
        # clip to the 0.5/99.5 percentiles and then apply mean/std
        # normalisation (confirm against ClipAndNormalize).
        tmp["ClipAndNormalize.mean"] = state_dict["mean"]
        tmp["ClipAndNormalize.std"] = state_dict["std"]
        tmp["ClipAndNormalize.clip_min"] = state_dict["percentile_00_5"]
        tmp["ClipAndNormalize.clip_max"] = state_dict["percentile_99_5"]

    state_dict = {"Model": {model.name: tmp}}
    model.load(state_dict)
    dest_path = Path(f"./{task}")
    dest_path.mkdir(exist_ok=True)
    torch.save(state_dict, str(dest_path / f"{model_name}.pt"))
def download(url: str, model_name: str, mri: bool) -> dict[str, torch.Tensor]:
    """Download a weight archive and return the network weights it contains.

    The ZIP at *url* is streamed into the current working directory, named
    after the last URL component, and extracted next to it. The archive and
    the extracted directory are always removed before returning, including
    when an error occurs part-way through.

    Args:
        url: Direct URL of the ``.zip`` release asset.
        model_name: Label shown in the progress bar.
        mri: When False, four statistics are added to the returned dict as
            1-element tensors: ``mean``, ``std``, ``percentile_00_5`` and
            ``percentile_99_5``. They are read from the channel-0 foreground
            intensity properties in ``dataset_fingerprint.json``.

    Returns:
        The ``network_weights`` entry of the first ``checkpoint_final.pth``
        found in the archive, plus the statistics above for non-MRI models.

    Raises:
        requests.HTTPError: If the download fails.
        FileNotFoundError: If the archive lacks ``checkpoint_final.pth``, or
            ``dataset_fingerprint.json`` when *mri* is False.
    """
    archive_name = url.split("/")[-1]
    archive = Path(archive_name)
    extract_dir = Path(archive_name.replace(".zip", ""))
    try:
        # Send the request before creating the local file, so a failed
        # download does not leave an empty archive behind.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with open(archive, "wb") as f, tqdm(
                total=total_size, unit="B", unit_scale=True, desc=f"Downloading {model_name}"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192 * 16):
                    progress_bar.update(len(chunk))
                    f.write(chunk)

        with zipfile.ZipFile(archive, "r") as zip_f:
            zip_f.extractall(extract_dir)
        # Free the archive's disk space as soon as it has been unpacked.
        archive.unlink()

        checkpoint_path = next(extract_dir.rglob("checkpoint_final.pth"), None)
        if checkpoint_path is None:
            raise FileNotFoundError(f"No checkpoint_final.pth found in archive {url}")
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only pass trusted archive URLs.
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)["network_weights"]

        if not mri:
            fingerprint_path = next(extract_dir.rglob("dataset_fingerprint.json"), None)
            if fingerprint_path is None:
                raise FileNotFoundError(f"No dataset_fingerprint.json found in archive {url}")
            with open(fingerprint_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            ch0 = data["foreground_intensity_properties_per_channel"]["0"]
            for stat in ("mean", "std", "percentile_00_5", "percentile_99_5"):
                state_dict[stat] = torch.tensor([ch0[stat]])
    finally:
        # Clean up on both the success and the failure path.
        archive.unlink(missing_ok=True)
        if extract_dir.exists():
            shutil.rmtree(extract_dir)
    return state_dict
# Base URL of the TotalSegmentator GitHub release assets. Each entry of
# `models` appends "<release>-weights/<archive>.zip" to it.
url = "https://github.com/wasserth/TotalSegmentator/releases/download/"

# Network constructors pre-bound to a channel list. The list starts at 1,
# presumably the input channel count. The numeric suffix matches the `type`
# value paired with the constructor in `models`.
UnetCPP_1_CT = partial(Unet_TS_CT, channels=[1, 32, 64, 128, 256, 320, 320])
UnetCPP_2_CT = partial(Unet_TS_CT, channels=[1, 32, 64, 128, 256, 320])
UnetCPP_3_CT = partial(Unet_TS_CT, channels=[1, 32, 64, 128, 256])
UnetCPP_1_MR = partial(Unet_TS_MR, channels=[1, 32, 64, 128, 256, 320, 320])
UnetCPP_2_MR = partial(Unet_TS_MR, channels=[1, 32, 64, 128, 256, 320])
UnetCPP_3_MR = partial(Unet_TS_MR, channels=[1, 32, 64, 128, 256])

# Conversion table: output name -> (network instance, output sub-directory,
# network type, is_mri, weight archive URL). The tuple is unpacked
# positionally after the name into
# convert_torchScript_full(model_name, model, task, type, mri, url).
# Every network below is instantiated at import time.
models = {
    # CT models, v2.0.0 weights
    "M291": (UnetCPP_1_CT(), "total", 1, False, f"{url}v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip"),
    "M292": (UnetCPP_1_CT(), "total", 1, False, f"{url}v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip"),
    "M293": (UnetCPP_1_CT(), "total", 1, False, f"{url}v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip"),
    "M294": (UnetCPP_1_CT(), "total", 1, False, f"{url}v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
    "M295": (UnetCPP_1_CT(), "total", 1, False, f"{url}v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
    # CT model, v2.0.4 weights
    "M297": (UnetCPP_2_CT(), "total-3mm", 2, False, f"{url}v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
    # NOTE(review): the disabled entries below use an older tuple layout.
    # It has no task directory, and in M73x the bool precedes the type.
    # Rewrite them as (model, task, type, mri, url) before re-enabling.
    #"M298" : (UnetCPP_2_CT(), 2, False, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
    #"M730" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
    #"M731" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
    #"M732" : (UnetCPP_2_MR(), False, 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
    #"M733" : (UnetCPP_3_MR(), False, 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
    # MR models, v2.5.0 weights
    "M850": (UnetCPP_1_MR(), "total_mr", 1, True, f"{url}v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
    "M851": (UnetCPP_1_MR(), "total_mr", 1, True, f"{url}v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
    "M852": (UnetCPP_2_MR(), "total_mr-3mm", 2, True, f"{url}v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
    #"M853" : (UnetCPP_3_MR(), False, 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip")
}
if __name__ == "__main__":
    # Convert every enabled entry of `models`, in insertion order. Each value
    # tuple supplies the remaining positional arguments.
    for model_name, conversion_args in models.items():
        convert_torchScript_full(model_name, *conversion_args)