| import torch
|
| import torchvision
|
| from torch import nn
|
| from torchvision.utils import save_image
|
| from torch.utils.data import DataLoader
|
| from torch.optim import Adam
|
| from torchvision.utils import make_grid
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN
|
| from torchvision.transforms import Lambda
|
| import random
|
| import os
|
| import utils
|
| from Dataloader.dataloader0 import get_dataloader
|
| from Dataloader.dataLoader import *
|
|
|
| from torchvision.utils import save_image
|
| from einops import rearrange, reduce, repeat
|
|
|
| import numpy as np
|
| import nibabel as nib
|
| from tqdm import tqdm
|
| import yaml
|
| import argparse
|
|
|
# Small numerical-stability constant. Same float value as the original literal
# `10e-8` (which equals 1e-7, not 1e-8).
# NOTE(review): nothing visible in this script reads EPS; if 1e-8 was the
# intended magnitude, confirm before relying on it.
EPS = 1e-7

# Command-line interface: the only option is the path to the YAML config file.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    required=False,
    default="Config/config_cmr.yaml",
    help="Path for the config file",
)
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the hyper-parameter dictionary from the YAML file given on the command
# line. safe_load is used so the config cannot instantiate arbitrary objects.
with open(args.config, "r", encoding="utf-8") as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Create the augmentation output directories named in the config.
# exist_ok=True replaces the racy "check exists, then create" pattern.
# NOTE(review): these three paths are overwritten later in the script with
# dataset-specific "Data/Aug_data/..." paths, so the directories created and
# the path printed here are not where the outputs are finally written.
for _savepath_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_savepath_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Inference runs one sample per batch regardless of the configured batch size.
hyp_parameters['batchsize'] = 1
|
|
|
|
|
|
|
# Which data to augment: database "MnMs", label key "heart", image subtype "es".
min_crop_ratio = 0.9
label_keys = ['heart']
database = ['MnMs']
subtype = "es"

# Send all augmentation outputs to a subtype-specific folder, replacing the
# output paths that were read from the config file.
hyp_parameters["aug_img_savepath"] = f"Data/Aug_data/mnms_{subtype}/img/"
hyp_parameters["aug_msk_savepath"] = f"Data/Aug_data/mnms_{subtype}/msk/"
hyp_parameters["aug_ddf_savepath"] = f"Data/Aug_data/mnms_{subtype}/ddf/"

# Only the selected subtype is requested from the dataset's "ImgDict".
select_channels_dict = {"ImgDict": [subtype]}

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
    select_channels_dict=select_channels_dict,
)
# Deterministic order so output file names line up with dataset indices.
Infer_Loader = DataLoader(dataset, batch_size=hyp_parameters['batchsize'], shuffle=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint location:
#   Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
# NOTE: despite its name, `epoch` holds a model-identifier string, not an epoch number.
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
_model_dir = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(_model_dir, epoch + '.pth')
|
|
|
|
|
|
|
# Resolve the network class named in the config, instantiate it, and wrap it
# in DeformDDPM; then build an STN on the same device.
Net = get_net(hyp_parameters["net_name"])

_denoiser = Net(
    n_steps=hyp_parameters["timesteps"],
    ndims=hyp_parameters["ndims"],
    num_input_chn=hyp_parameters["num_input_chn"],
    res=hyp_parameters['img_size'],
)

# Image shape handed to DeformDDPM: [num_input_chn] followed by img_size
# repeated once per spatial dimension (ndims times).
_image_chw = [hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]] * hyp_parameters["ndims"]

Deformddpm = DeformDDPM(
    network=_denoiser,
    n_steps=hyp_parameters["timesteps"],
    image_chw=_image_chw,
    device=hyp_parameters["device"],
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
)
Deformddpm.to(hyp_parameters["device"])

# NOTE(review): ddf_stn is not referenced anywhere in the visible part of this script.
ddf_stn = STN(
    img_sz=hyp_parameters["img_size"],
    ndims=hyp_parameters["ndims"],
    padding_mode=hyp_parameters['padding_mode'],
    device=hyp_parameters["device"],
)
ddf_stn.to(hyp_parameters["device"])
|
|
|
print("Loading model from:", model_save_path)

# map_location places the stored tensors directly on the configured device, so
# a checkpoint saved on a GPU can still be loaded on a CPU-only machine (and
# avoids materialising a second copy on the original device).
checkpoint = torch.load(model_save_path, map_location=hyp_parameters["device"])
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()

# Ensure the (possibly overridden) output directories exist before writing.
for _savepath_key in ('aug_img_savepath', 'aug_msk_savepath', 'aug_ddf_savepath'):
    os.makedirs(hyp_parameters[_savepath_key], exist_ok=True)
|
|
|
print("total num of image:", len(Infer_Loader))

# NOTE(review): the original loop ended with `if e >= 0: exit()`. That condition
# is always true, so only the first sample was ever processed. The cap is kept
# as the default but is now explicit: set "max_infer_subjects" in the config
# (or null to disable the cap). At least one sample is always processed.
max_infer_subjects = hyp_parameters.get("max_infer_subjects", 1)

# Wrap the loader itself (not the enumerate object) so tqdm knows the total.
for e, d in enumerate(tqdm(Infer_Loader)):
    img = d['img']
    mask = d['labels']

    # No patient id is read from the batch; the sample index doubles as one.
    pid = e
    print('Processing to patient:', pid, ' image:', e)

    # Take host-side NumPy copies of the inputs *before* moving them to the
    # device, instead of copying device -> host just to save the originals.
    img = img.type(torch.float32)
    image_original = img.cpu().detach().numpy()
    img = img.to(hyp_parameters["device"])

    mask = mask.type(torch.float32)
    mask_original = mask.cpu().detach().numpy()
    mask = mask.to(hyp_parameters["device"])

    # Save the un-augmented image/mask next to the augmented outputs.
    nifti_img = utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"])
    nifti_mask = utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"])
    original_barcode = utils.get_barcode([pid, e])
    nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'], original_barcode + '.nii.gz'))
    nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'], original_barcode + '_GT.nii.gz'))

    # The noise level restarts from start_noise_step for every sample.
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(hyp_parameters["aug_coe"]):
            print('Generating - >', 'Subject-', pid, ', Scan-', e, ' (', im, '/', hyp_parameters["aug_coe"], ')', end='\r')

            [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = Deformddpm.diff_recover(
                img_org=img,
                msk_org=mask,
                T=[noise_step, hyp_parameters["timesteps"]],
                v_scale=hyp_parameters["v_scale"],
                t_save=None,
                proc_type=hyp_parameters["condition_type"],
            )

            # Recovered (denoised) and intermediate noisy outputs, as NumPy arrays.
            denoise_imgs = img_rec.cpu().detach().numpy()
            denoise_msks = msk_rec.cpu().detach().numpy()
            noisy_imgs_np = img_diff.cpu().detach().numpy()
            noisy_msks_np = msk_diff.cpu().detach().numpy()

            nifti_img_aug = utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"])
            nifti_mask_aug = utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"])
            nifti_img = utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"])
            nifti_mask = utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"])

            # One barcode for the augmented pair, a differently-headed one for
            # the noisy pair, so both can share the same output directories.
            aug_barcode = utils.get_barcode([pid, e, im, noise_step])
            noisy_barcode = utils.get_barcode([pid, e, im, noise_step], header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep'])

            nib.save(nifti_img_aug, os.path.join(hyp_parameters['aug_img_savepath'], aug_barcode + '.nii.gz'))
            nib.save(nifti_mask_aug, os.path.join(hyp_parameters['aug_msk_savepath'], aug_barcode + '_GT.nii.gz'))
            nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'], noisy_barcode + '.nii.gz'))
            nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'], noisy_barcode + '_GT.nii.gz'))

            # NOTE(review): this subtracts a timestep value (start_noise_step)
            # from an augmentation index (im), so which iterations raise the
            # noise level depends on the parity of start_noise_step. Kept as-is;
            # confirm whether e.g. `im % 2 == 0` was intended.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    if max_infer_subjects is not None and e + 1 >= max_infer_subjects:
        break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|