# build.py: converts TotalSegmentator nnU-Net checkpoints into this project's model format.
# Last change: "Update build.py" by VBoussot (commit 28814c0).
import json
import os
import shutil
import tempfile
import zipfile
from functools import partial
from pathlib import Path

import requests
import torch
from tqdm import tqdm

from total.model import Unet_TS_CT
from total_mr.model import Unet_TS_MR
# For each architecture variant (the `type` argument), the one `decoder.seg_layers.<i>`
# entry of the checkpoint that is transferred; every other seg_layers entry is skipped.
_KEPT_SEG_LAYER = {
    1: "decoder.seg_layers.4",
    2: "decoder.seg_layers.3",
    3: "decoder.seg_layers.2",
}


def _next_source_key(keys, dest_name: str) -> str:
    """Advance the source-key iterator, turning exhaustion into a descriptive error."""
    try:
        return next(keys)
    except StopIteration:
        raise ValueError(
            f"Checkpoint ran out of parameters while mapping destination key {dest_name!r}"
        ) from None


def convert_torchScript_full(model_name: str, model: torch.nn.Module, task: str, type: int, mri: bool, url: str) -> None:
    """Download a TotalSegmentator checkpoint and re-save it under this project's parameter names.

    Source tensors are assigned to destination names purely by position. The i-th
    non-blank line of ``Destination_Unet_{type}.txt`` receives the i-th checkpoint
    entry that survives two skip rules:

    * ``decoder.seg_layers`` entries other than the one listed in ``_KEPT_SEG_LAYER``
      for this ``type``;
    * entries whose key contains ``all_modules`` or ``decoder.encoder`` (presumably
      duplicate aliases of tensors stored elsewhere in the checkpoint; confirm).

    Args:
        model_name: Identifier shown while downloading; also the output file stem.
        model: Destination network. Its ``load`` method is called on the remapped
            dict before saving (presumably as a validation step; confirm).
        task: Output sub-directory (relative to the current directory).
        type: Architecture variant (1, 2 or 3). Selects the destination key list and
            the kept segmentation head. Shadows the builtin; the name is kept for
            compatibility.
        mri: When False, the CT normalisation statistics returned by ``download`` are
            stored as ``ClipAndNormalize.{mean,std,clip_min,clip_max}``.
        url: Download URL of the zipped nnU-Net result folder.

    Raises:
        ValueError: If the destination list has more entries than mappable checkpoint keys.
    """
    state_dict = download(url, model_name, mri)
    kept_seg_layer = _KEPT_SEG_LAYER.get(type)
    tmp = {}
    with open(f"Destination_Unet_{type}.txt", encoding="utf-8") as f2:
        it = iter(state_dict.keys())
        for l1 in f2:
            # strip() also removes '\r' from CRLF files, which would corrupt the key name.
            dest_name = l1.strip()
            if not dest_name:
                # A blank line is not a parameter name and must not consume a source tensor.
                continue
            key = _next_source_key(it, dest_name)
            # Skip every segmentation head except the one this variant keeps.
            while "decoder.seg_layers" in key and (kept_seg_layer is None or kept_seg_layer not in key):
                key = _next_source_key(it, dest_name)
            while "all_modules" in key or "decoder.encoder" in key:
                key = _next_source_key(it, dest_name)
            # Print the resolved pair so the trace shows the key actually transferred.
            print(f"{dest_name} <- {key}")
            tmp[dest_name] = state_dict[key]
    if not mri:
        # Intensity statistics added by `download` for CT checkpoints.
        tmp["ClipAndNormalize.mean"] = state_dict["mean"]
        tmp["ClipAndNormalize.std"] = state_dict["std"]
        tmp["ClipAndNormalize.clip_min"] = state_dict["percentile_00_5"]
        tmp["ClipAndNormalize.clip_max"] = state_dict["percentile_99_5"]
    state_dict = {"Model": {model.name: tmp}}
    model.load(state_dict)
    dest_path = Path(f"./{task}")
    dest_path.mkdir(exist_ok=True)
    torch.save(state_dict, str(dest_path / f"{model_name}.pt"))
def _fetch_archive(url: str, archive: Path, model_name: str) -> None:
    """Stream ``url`` into the file ``archive`` while showing a tqdm progress bar."""
    with open(archive, "wb") as f:
        # Bound the connect and each socket read; without a timeout a stalled server hangs forever.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Downloading {model_name}") as progress_bar:
                for chunk in r.iter_content(chunk_size=8192 * 16):
                    progress_bar.update(len(chunk))
                    f.write(chunk)


def _read_network_weights(extract_dir: Path, mri: bool) -> dict[str, torch.Tensor]:
    """Load ``network_weights`` from an extracted result folder, adding CT intensity stats unless ``mri``."""
    checkpoint_path = next(extract_dir.rglob("checkpoint_final.pth"), None)
    if checkpoint_path is None:
        raise FileNotFoundError("checkpoint_final.pth not found in the downloaded archive")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary objects, which can run
    # code embedded in the archive. Only use this with trusted release assets.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)["network_weights"]
    if not mri:
        fingerprint_path = next(extract_dir.rglob("dataset_fingerprint.json"), None)
        if fingerprint_path is None:
            raise FileNotFoundError("dataset_fingerprint.json not found in the downloaded archive")
        with open(fingerprint_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        ch0 = data["foreground_intensity_properties_per_channel"]["0"]
        for stat in ("mean", "std", "percentile_00_5", "percentile_99_5"):
            state_dict[stat] = torch.tensor([ch0[stat]])
    return state_dict


def download(url: str, model_name: str, mri: bool) -> dict[str, torch.Tensor]:
    """Download a zipped nnU-Net result folder and return its final network weights.

    The archive is streamed into a temporary working directory created under the
    current directory and extracted there. The whole directory is removed before
    returning, including when any step fails.

    Args:
        url: Link to a ``.zip`` archive containing ``checkpoint_final.pth`` and, for CT,
            ``dataset_fingerprint.json``.
        model_name: Label shown in the progress bar.
        mri: When False, channel-0 foreground intensity statistics from the dataset
            fingerprint are added as 1-element tensors under ``mean``, ``std``,
            ``percentile_00_5`` and ``percentile_99_5``.

    Returns:
        The checkpoint's ``network_weights`` mapping, plus the CT statistics when ``mri`` is False.

    Raises:
        requests.HTTPError: On a non-success HTTP response.
        FileNotFoundError: If the archive lacks an expected file.
    """
    # Keep the working directory next to the outputs rather than in the system temp
    # location, which may be too small for multi-hundred-MB checkpoints.
    with tempfile.TemporaryDirectory(dir=".", prefix="ts_weights_") as work:
        work_dir = Path(work)
        archive = work_dir / "archive.zip"
        extract_dir = work_dir / "extracted"
        _fetch_archive(url, archive, model_name)
        with zipfile.ZipFile(archive, "r") as zip_f:
            zip_f.extractall(extract_dir)
        # Tensors are fully loaded into memory, so deleting the folder afterwards is safe.
        return _read_network_weights(extract_dir, mri)
# Root of TotalSegmentator's GitHub release assets; archive URLs are "<root><version>-weights/<archive>.zip".
url = "https://github.com/wasserth/TotalSegmentator/releases/download/"


def _weights(release: str, archive: str) -> str:
    """Return the download URL of weights archive ``archive`` published in ``release``."""
    return f"{url}{release}-weights/{archive}.zip"


# Channel lists for the three U-Net depth variants (variant 1 has the most stages).
# Each UnetCPP_* factory's numeric suffix matches the `type` value it is paired with in `models`.
_CHANNELS = {
    1: (1, 32, 64, 128, 256, 320, 320),
    2: (1, 32, 64, 128, 256, 320),
    3: (1, 32, 64, 128, 256),
}

# list(...) gives every factory its own mutable channel list.
UnetCPP_1_CT = partial(Unet_TS_CT, channels=list(_CHANNELS[1]))
UnetCPP_2_CT = partial(Unet_TS_CT, channels=list(_CHANNELS[2]))
UnetCPP_3_CT = partial(Unet_TS_CT, channels=list(_CHANNELS[3]))
UnetCPP_1_MR = partial(Unet_TS_MR, channels=list(_CHANNELS[1]))
UnetCPP_2_MR = partial(Unet_TS_MR, channels=list(_CHANNELS[2]))
UnetCPP_3_MR = partial(Unet_TS_MR, channels=list(_CHANNELS[3]))

# model id -> (destination network, output directory, architecture type, is_mri, archive URL).
# The tuple is splatted positionally after the id into convert_torchScript_full.
models = {
    "M291": (UnetCPP_1_CT(), "total", 1, False, _weights("v2.0.0", "Dataset291_TotalSegmentator_part1_organs_1559subj")),
    "M292": (UnetCPP_1_CT(), "total", 1, False, _weights("v2.0.0", "Dataset292_TotalSegmentator_part2_vertebrae_1532subj")),
    "M293": (UnetCPP_1_CT(), "total", 1, False, _weights("v2.0.0", "Dataset293_TotalSegmentator_part3_cardiac_1559subj")),
    "M294": (UnetCPP_1_CT(), "total", 1, False, _weights("v2.0.0", "Dataset294_TotalSegmentator_part4_muscles_1559subj")),
    "M295": (UnetCPP_1_CT(), "total", 1, False, _weights("v2.0.0", "Dataset295_TotalSegmentator_part5_ribs_1559subj")),
    "M297": (UnetCPP_2_CT(), "total-3mm", 2, False, _weights("v2.0.4", "Dataset297_TotalSegmentator_total_3mm_1559subj_v204")),
    # NOTE(review): the disabled entries below use an older 4-field tuple with no output
    # directory, and the bool/int fields appear in varying order. They will not unpack into
    # convert_torchScript_full as written. M732/M733 also carry False where the MRI flag
    # presumably belongs. Fix the layout before re-enabling.
    #"M298" : (UnetCPP_2_CT(), 2, False, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
    #"M730" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
    #"M731" : (UnetCPP_1_MR(), True, 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
    #"M732" : (UnetCPP_2_MR(), False, 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
    #"M733" : (UnetCPP_3_MR(), False, 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
    "M850": (UnetCPP_1_MR(), "total_mr", 1, True, _weights("v2.5.0", "Dataset850_TotalSegMRI_part1_organs_1088subj")),
    "M851": (UnetCPP_1_MR(), "total_mr", 1, True, _weights("v2.5.0", "Dataset851_TotalSegMRI_part2_muscles_1088subj")),
    "M852": (UnetCPP_2_MR(), "total_mr-3mm", 2, True, _weights("v2.5.0", "Dataset852_TotalSegMRI_total_3mm_1088subj")),
    #"M853" : (UnetCPP_3_MR(), False, 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip")
}
if __name__ == "__main__":
    # Convert every registered checkpoint, in declaration order.
    for model_id, (network, task_dir, arch_type, is_mri, weights_url) in models.items():
        convert_torchScript_full(model_id, network, task_dir, arch_type, is_mri, weights_url)