"""
OM_aug_om.py — Augmentation using OMorpher.

Drop-in replacement for OM_aug.py. Produces identical outputs but uses
OMorpher instead of DeformDDPM + STN + standalone apply_ddf().

Usage:
    python Scripts/OM_aug_om.py -C Config/config_om.yaml
"""

import os
import sys
import argparse

# Put the parent of this script's directory at the front of sys.path.
# Presumably this is what lets the project-local imports further down
# (utils, Dataloader, OMorpher) resolve when the script is run directly,
# so it must stay ahead of them.
_script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_script_dir))

import numpy as np
import torch
import nibabel as nib
import yaml
from tqdm import tqdm

import utils
from Dataloader.dataLoader import OminiDataset_inference_w_all
from torch.utils.data import DataLoader
from OMorpher import OMorpher


# Command-line interface: one optional argument giving the YAML config path.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    required=False,
    default="Config/config_cmr.yaml",
    help="Path for the config file",
)
args = parser.parse_args()


# Load the run configuration from the YAML file chosen on the command line
# and echo it for the run log.
with open(args.config, "r") as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Make sure the output directories named in the config exist.
# exist_ok=True replaces the previous "check os.path.exists, then create"
# sequence, which could fail if the directory appeared between the two calls.
# NOTE: these three paths are reassigned further down in this script before
# any file is written, so the directories created here are the config-file
# ones, not necessarily the ones that end up being used.
for _savepath_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_savepath_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Force a batch size of one; this value is later passed to the DataLoader.
hyp_parameters["batchsize"] = 1


# Hard-coded experiment selection for this run.
subtype = "es"
database = ["MnMs"]
label_keys = ["heart"]
min_crop_ratio = 0.9
select_channels_dict = {"ImgDict": [subtype]}

# Override the save locations from the config file with subtype-specific
# directories under Data/Aug_data/.
hyp_parameters.update(
    aug_img_savepath=f"Data/Aug_data/mnms_{subtype}/img/",
    aug_msk_savepath=f"Data/Aug_data/mnms_{subtype}/msk/",
    aug_ddf_savepath=f"Data/Aug_data/mnms_{subtype}/ddf/",
)

# Build the inference dataset and wrap it in a loader that keeps dataset order.
dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
    select_channels_dict=select_channels_dict,
)
Infer_Loader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=hyp_parameters["batchsize"],
)


# Checkpoint location: Models/<data_name>_<net_name>/<epoch>.pth, where
# `epoch` is the identifier string "<model_id_str>_<data_name>_<net_name>".
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
_model_dir = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(_model_dir, f"{epoch}.pth")
print("Loading model from:", model_save_path)

# Build the OMorpher from the loaded config and checkpoint; the device comes
# from the config and falls back to "cpu" when no "device" key is present.
_device_name = str(hyp_parameters.get("device", "cpu"))
om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=_device_name,
)
print(om)


# Create the image, mask and DDF output directories (no error if they
# already exist).
for _out_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_out_key], exist_ok=True)


def _save_nifti(array, directory, filename):
    """Convert *array* with utils.converet_to_nibabel and save it.

    Args:
        array: NumPy array to write.
        directory: Output directory (must already exist).
        filename: File name, including the ".nii.gz" suffix.
    """
    nifti = utils.converet_to_nibabel(array, ndims=hyp_parameters["ndims"])
    nib.save(nifti, os.path.join(directory, filename))


device = om.device
print("total num of image:", len(Infer_Loader))

# Loop-invariant configuration, looked up once.
aug_coe = hyp_parameters["aug_coe"]
img_dir = hyp_parameters["aug_img_savepath"]
msk_dir = hyp_parameters["aug_msk_savepath"]
noisy_header = ["Patient", "Slice", "NoiseImg", "NoiseStep"]

# tqdm wraps the loader itself (not the enumerate object) so that it can read
# the loader's length and show a total / ETA.
for e, d in enumerate(tqdm(Infer_Loader)):
    # The sample index doubles as the patient id.
    pid = e
    print("Processing to patient:", pid, " image:", e)

    img = d["img"].type(torch.float32).to(device)
    mask = d["labels"].type(torch.float32).to(device)

    # Save the unmodified image and its mask next to the augmented outputs.
    source_code = utils.get_barcode([pid, e])
    _save_nifti(img.cpu().detach().numpy(), img_dir, source_code + ".nii.gz")
    _save_nifti(mask.cpu().detach().numpy(), msk_dir, source_code + "_GT.nii.gz")

    # The noise step restarts from the configured value for every sample.
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(aug_coe):
            print(
                f"Generating -> Subject-{pid}, Scan-{e} "
                f"({im}/{aug_coe})",
                end="\r",
            )

            # Register the same tensor as both the initial and the
            # conditioning image.
            om.set_init_img(img)
            om.set_cond_img(img)

            # Draw a random deformation at the current noise step.
            # NOTE(review): relies on OMorpher's private `_get_random_ddf`
            # and `_init_img`; presumably img_diff is the randomly deformed
            # image and ddf_rand the displacement field — confirm.
            t_start = torch.tensor(np.array([noise_step]), device=device)
            img_diff, _, ddf_rand = om._get_random_ddf(om._init_img, t_start)

            # Apply the same random field to the mask, using nearest
            # resampling.
            msk_diff = om.apply_def(
                img=mask,
                ddf=ddf_rand,
                padding_mode="zeros",
                resample_mode="nearest",
            )

            # Start prediction from the random field, running from the
            # current noise step up to the configured number of timesteps.
            om.set_init_def(ddf=ddf_rand.clone().detach())
            om.predict(
                T=[noise_step, hyp_parameters["timesteps"]],
                proc_type=hyp_parameters["condition_type"],
            )

            # Apply the resulting field to the original image and mask.
            ddf_comp = om.get_def()
            img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
            msk_rec = om.apply_def(
                img=mask,
                ddf=ddf_comp,
                padding_mode="zeros",
                resample_mode="nearest",
            )

            # Augmented outputs (after prediction).
            aug_code = utils.get_barcode([pid, e, im, noise_step])
            _save_nifti(img_rec.cpu().detach().numpy(), img_dir, aug_code + ".nii.gz")
            _save_nifti(msk_rec.cpu().detach().numpy(), msk_dir, aug_code + "_GT.nii.gz")

            # Randomly deformed inputs (before prediction), saved under a
            # barcode built with the "NoiseImg" header.
            noisy_code = utils.get_barcode(
                [pid, e, im, noise_step],
                header=noisy_header,
            )
            _save_nifti(img_diff.cpu().detach().numpy(), img_dir, noisy_code + ".nii.gz")
            _save_nifti(msk_diff.cpu().detach().numpy(), msk_dir, noisy_code + "_GT.nii.gz")

            # NOTE(review): the parity test mixes the augmentation index
            # with the configured start step, so which iterations advance
            # the noise step depends on whether start_noise_step is odd or
            # even — confirm this is intended.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # NOTE(review): this condition is always true, so the script stops after
    # the first sample. Kept as-is (placement at per-sample level assumed);
    # confirm whether it is a leftover debugging limit.
    if e >= 0:
        # sys.exit() rather than the interactive `exit()` helper, which is
        # injected by the `site` module and is not guaranteed to exist.
        sys.exit()