| import torch |
| import requests |
| from tqdm import tqdm |
| import zipfile |
| import shutil |
| from pathlib import Path |
| import os |
| from functools import partial |
| from total.model import Unet_TS_CT |
| from total_mr.model import Unet_TS_MR |
| import json |
|
|
def convert_torchScript_full(model_name: str, model: torch.nn.Module, task: str, type: int, mri: bool, url: str) -> None:
    """Rename the weights of a downloaded nnU-Net checkpoint for ``model`` and save them.

    Destination parameter names are read, one per line, from
    ``Destination_Unet_{type}.txt`` (relative to the working directory).
    Each name is paired, in order, with the next usable key of the state
    dict returned by ``download``. A source key is skipped when it is:

    * a ``decoder.seg_layers.<i>`` entry other than the kept index
      (4, 3 or 2 for ``type`` 1, 2 or 3), or
    * an ``all_modules`` or ``decoder.encoder`` entry.

    For CT models (``mri`` false), the intensity statistics added by
    ``download`` are stored as ``ClipAndNormalize.{mean,std,clip_min,clip_max}``.
    The result is wrapped as ``{"Model": {model.name: weights}}``, passed to
    ``model.load`` and written to ``./{task}/{model_name}.pt``.

    Args:
        model_name: Identifier used for the progress bar and the output file name.
        model: Target network; must expose ``name`` and ``load``.
        task: Output folder (created if missing).
        type: Unet variant, 1, 2 or 3; selects the destination name list and
            the kept seg layer. (Shadows the builtin; name kept for callers.)
        mri: True for MR checkpoints, which carry no intensity statistics.
        url: Release-asset URL of the zipped nnU-Net results.

    Raises:
        ValueError: ``type`` is not 1, 2 or 3, or the checkpoint runs out of
            usable keys before every destination name has been mapped.
    """
    final_seg_layer = {1: 4, 2: 3, 3: 2}
    if type not in final_seg_layer:
        # Check before downloading: with any other type, every seg_layers
        # key would be skipped and later names would receive the wrong tensors.
        raise ValueError(
            f"Unsupported Unet type {type!r}; expected one of {sorted(final_seg_layer)}"
        )
    # Trailing dot so that index 4 cannot also match e.g. "seg_layers.40".
    kept_seg_layer = f"decoder.seg_layers.{final_seg_layer[type]}."

    def _is_skipped(key: str) -> bool:
        """Return True for source keys that get no destination name."""
        if "decoder.seg_layers" in key:
            return kept_seg_layer not in key
        return "all_modules" in key or "decoder.encoder" in key

    state_dict = download(url, model_name, mri)
    tmp: dict[str, torch.Tensor] = {}
    source_keys = iter(state_dict)
    with open(f"Destination_Unet_{type}.txt") as dest_file:
        for line in dest_file:
            dest_name = line.strip()
            if not dest_name:
                # A blank line must not consume a source key.
                continue
            # Advance the shared iterator to the next key that is kept.
            key = next((k for k in source_keys if not _is_skipped(k)), None)
            if key is None:
                raise ValueError(
                    f"Checkpoint of {model_name} ran out of weights while mapping '{dest_name}'"
                )
            print(f"{dest_name} <- {key}")
            tmp[dest_name] = state_dict[key]

    if not mri:
        # Statistics that download() adds from dataset_fingerprint.json for CT.
        tmp["ClipAndNormalize.mean"] = state_dict["mean"]
        tmp["ClipAndNormalize.std"] = state_dict["std"]
        tmp["ClipAndNormalize.clip_min"] = state_dict["percentile_00_5"]
        tmp["ClipAndNormalize.clip_max"] = state_dict["percentile_99_5"]

    state_dict = {"Model": {model.name: tmp}}
    # NOTE(review): ``name``/``load`` are not torch.nn.Module APIs; presumably
    # provided by the Unet_TS_* classes. Confirm that load() rejects mismatched keys.
    model.load(state_dict)
    dest_path = Path(f"./{task}")
    dest_path.mkdir(parents=True, exist_ok=True)
    torch.save(state_dict, str(dest_path / f"{model_name}.pt"))
|
|
def download(url: str, model_name: str, mri: bool) -> dict[str, torch.Tensor]:
    """Download a zipped nnU-Net result folder and return its final network weights.

    The archive is saved in the working directory under the last path
    component of ``url``. It is extracted into a folder of the same name
    without the ``.zip`` suffix. Both the archive and the folder are removed
    before returning, including when an error is raised.

    Args:
        url: Direct link to the ``.zip`` release asset.
        model_name: Label shown in the download progress bar.
        mri: When false, the ``foreground_intensity_properties_per_channel["0"]``
            values ``mean``, ``std``, ``percentile_00_5`` and ``percentile_99_5``
            from ``dataset_fingerprint.json`` are added to the result under those
            keys as one-element tensors.

    Returns:
        The ``network_weights`` entry of the first ``checkpoint_final.pth``
        found in the archive, loaded on CPU, plus the statistics above when
        ``mri`` is false.

    Raises:
        requests.HTTPError: The server answered with an error status.
        FileNotFoundError: The archive has no ``checkpoint_final.pth``, or
            (CT only) no ``dataset_fingerprint.json``.
    """
    archive_name = url.split("/")[-1]
    archive_path = Path(archive_name)
    extract_dir = Path(archive_name.removesuffix(".zip"))
    try:
        # The timeout bounds connect and per-read stalls, not the total
        # transfer time, so large archives still download.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            # Open the file only after a successful status, so HTTP errors
            # do not leave an empty file behind.
            with open(archive_path, "wb") as f, tqdm(
                total=total_size, unit="B", unit_scale=True, desc=f"Downloading {model_name}"
            ) as progress_bar:
                for chunk in r.iter_content(chunk_size=8192 * 16):
                    progress_bar.update(len(chunk))
                    f.write(chunk)

        with zipfile.ZipFile(archive_path, "r") as zip_f:
            zip_f.extractall(extract_dir)
        archive_path.unlink()

        checkpoint_path = next(extract_dir.rglob("checkpoint_final.pth"), None)
        if checkpoint_path is None:
            raise FileNotFoundError(f"No checkpoint_final.pth found in {archive_name}")
        # SECURITY: weights_only=False unpickles arbitrary objects from a
        # downloaded file. Only use URLs whose publisher you trust.
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)["network_weights"]

        if not mri:
            fingerprint_path = next(extract_dir.rglob("dataset_fingerprint.json"), None)
            if fingerprint_path is None:
                raise FileNotFoundError(f"No dataset_fingerprint.json found in {archive_name}")
            with open(fingerprint_path, "r") as f:
                ch0 = json.load(f)["foreground_intensity_properties_per_channel"]["0"]
            for stat in ("mean", "std", "percentile_00_5", "percentile_99_5"):
                state_dict[stat] = torch.tensor([ch0[stat]])
        return state_dict
    finally:
        # Clean up on success and on failure alike.
        archive_path.unlink(missing_ok=True)
        shutil.rmtree(extract_dir, ignore_errors=True)
|
|
# Prefix shared by every TotalSegmentator release asset. The archive URLs in
# ``models`` append "<release>-weights/<archive>.zip" to it.
url = "https://github.com/wasserth/TotalSegmentator/releases/download/"
_V200_WEIGHTS = f"{url}v2.0.0-weights/"
_V204_WEIGHTS = f"{url}v2.0.4-weights/"
_V250_WEIGHTS = f"{url}v2.5.0-weights/"

# Constructors pre-bound to the channel list of each Unet type (1, 2, 3).
# The same type number is the third element of each ``models`` tuple.
UnetCPP_1_CT = partial(Unet_TS_CT, channels=[1, 32, 64, 128, 256, 320, 320])
UnetCPP_2_CT = partial(Unet_TS_CT, channels=[1, 32, 64, 128, 256, 320])
UnetCPP_3_CT = partial(Unet_TS_CT, channels=[1, 32, 64, 128, 256])

UnetCPP_1_MR = partial(Unet_TS_MR, channels=[1, 32, 64, 128, 256, 320, 320])
UnetCPP_2_MR = partial(Unet_TS_MR, channels=[1, 32, 64, 128, 256, 320])
UnetCPP_3_MR = partial(Unet_TS_MR, channels=[1, 32, 64, 128, 256])

# model id -> (model instance, output folder, Unet type, is_mri, archive URL).
# Every model listed here is instantiated once, at import time.
models = {
    # CT checkpoints
    "M291": (UnetCPP_1_CT(), "total", 1, False, _V200_WEIGHTS + "Dataset291_TotalSegmentator_part1_organs_1559subj.zip"),
    "M292": (UnetCPP_1_CT(), "total", 1, False, _V200_WEIGHTS + "Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip"),
    "M293": (UnetCPP_1_CT(), "total", 1, False, _V200_WEIGHTS + "Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip"),
    "M294": (UnetCPP_1_CT(), "total", 1, False, _V200_WEIGHTS + "Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
    "M295": (UnetCPP_1_CT(), "total", 1, False, _V200_WEIGHTS + "Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
    "M297": (UnetCPP_2_CT(), "total-3mm", 2, False, _V204_WEIGHTS + "Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
    # MR checkpoints
    "M850": (UnetCPP_1_MR(), "total_mr", 1, True, _V250_WEIGHTS + "Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
    "M851": (UnetCPP_1_MR(), "total_mr", 1, True, _V250_WEIGHTS + "Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
    "M852": (UnetCPP_2_MR(), "total_mr-3mm", 2, True, _V250_WEIGHTS + "Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
}
if __name__ == "__main__":
    # Convert every registered checkpoint in declaration order.
    for model_id, (net, task, unet_type, is_mri, archive_url) in models.items():
        convert_torchScript_full(model_id, net, task, unet_type, is_mri, archive_url)