Omini3D / OM_reg.py
maxmo2009's picture
Sync from local: code + epoch-110 checkpoint, clean README
2af0e94 verified
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.utils import make_grid
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import random
import os
import utils
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *
from torchvision.utils import save_image
from einops import rearrange, reduce, repeat
# import matplotlib.image
import numpy as np
import nibabel as nib
from tqdm import tqdm
import yaml
import argparse
EPS = 10e-8
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
"-C",
help="Path for the config file",
type=str,
default="Config/config_cmr.yaml",
# default="Config/config_lct.yaml",
required=False,
)
args = parser.parse_args()
#=======================================================================================================================
# config_path = 'Config/config_cmr.yaml'
# config_path = 'Config/config_lct.yaml'
# Load the YAML file into a dictionary
with open(args.config, 'r') as file:
hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)
# hyp_parameters["aug_img_savepath"] = os.path.join(hyp_parameters["aug_img_savepath"],hyp_parameters["data_name"],'')
if not os.path.exists(hyp_parameters["aug_img_savepath"]):
os.makedirs(hyp_parameters["aug_img_savepath"])
if not os.path.exists(hyp_parameters["aug_msk_savepath"]):
os.makedirs(hyp_parameters["aug_msk_savepath"])
if not os.path.exists(hyp_parameters["aug_ddf_savepath"]):
os.makedirs(hyp_parameters["aug_ddf_savepath"])
print(hyp_parameters["aug_img_savepath"])
hyp_parameters['batchsize'] = 1
# =======================================================================================================================
# min_crop_ratio = 0.5
min_crop_ratio = 0.9
# dataset = OminiDataset_v1(transform=None,min_crop_ratio=min_crop_ratio)
# Infer_Loader = DataLoader(
# dataset,
# batch_size=hyp_parameters['batchsize'],
# shuffle=False
# )
# label_keys = ['heart']
label_keys = ['brain']
# label_keys = ['pancreas']
# database = ['MSD']
database = ['Brats2019']
dataset = OminiDataset_inference_w_all(transform=None,min_crop_ratio=min_crop_ratio,label_key = label_keys, database=database)
Infer_Loader = DataLoader(
dataset,
batch_size=hyp_parameters['batchsize'],
shuffle=False
)
# =======================================================================================================================
# Data_Loader=get_dataloader(hyp_parameters['data_name'],mode='aug')
# transformer = utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
# dataset = Data_Loader(patient_index = hyp_parameters["patients_list"])
# train_loader = DataLoader(dataset, batch_size = hyp_parameters['batchsize'], shuffle = False)
epoch=f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(model_save_path, str(epoch)+'.pth')
Net = get_net(hyp_parameters["net_name"])
Deformddpm = DeformDDPM(
network=Net(n_steps = hyp_parameters["timesteps"],
ndims = hyp_parameters["ndims"],
num_input_chn = hyp_parameters["num_input_chn"],
res = hyp_parameters['img_size']
),
n_steps = hyp_parameters["timesteps"],
image_chw = [hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
device = hyp_parameters["device"],
batch_size = hyp_parameters["batchsize"],
img_pad_mode = hyp_parameters["img_pad_mode"],
ddf_pad_mode = hyp_parameters["ddf_pad_mode"],
padding_mode = hyp_parameters["padding_mode"],
v_scale = hyp_parameters["v_scale"],
resample_mode = hyp_parameters["resample_mode"],
inf_mode = True, # set to True for inference, which will use fixed slice num and slice idx for better evaluation
)
Deformddpm.to(hyp_parameters["device"])
ddf_stn = STN(
img_sz = hyp_parameters["img_size"],
ndims = hyp_parameters["ndims"],
padding_mode = hyp_parameters['padding_mode'],
device = hyp_parameters["device"],
)
ddf_stn.to(hyp_parameters["device"])
print("Loading model from:", model_save_path)
# Deformddpm.load_state_dict(torch.load(model_save_path))
checkpoint = torch.load(model_save_path, map_location='cpu')
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()
os.makedirs(hyp_parameters['reg_img_savepath'], exist_ok=True)
os.makedirs(hyp_parameters['reg_msk_savepath'], exist_ok=True)
os.makedirs(hyp_parameters['reg_ddf_savepath'], exist_ok=True)
print("total num of image:", len(Infer_Loader))
for e, d in tqdm(enumerate(Infer_Loader)):
# for e, d in enumerate(Infer_Loader):
# img, mask, pid = d
# img = d
# mask = d
img = d['img']
mask = d['labels']
# pid = pid.cpu().detach().numpy()
# pid = pid[0]
pid = e
print('Processing to patient:', pid, ' image:',e)
img = img.to(hyp_parameters["device"])
img = img.type(torch.float32)
image_original = img.cpu().detach().numpy()
#
#
if e <= 0:
target_img = img.clone().detach() # save the first image as target image for conditioning
mask = mask.to(hyp_parameters["device"])
mask = mask.type(torch.float32)
mask_original = mask.cpu().detach().numpy()
# print(pid, image_original.shape, mask_original.max())
nifti_img = utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"])
nifti_mask = utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"])
# Saving original (undeformed image)
# CMR: format: Patient0001_Slice0001_ORG_NA.nii.gz
# Lung CT: Patient0001_Slice0001_ORG_NA.nii.gz
nib.save(nifti_img, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e])+'.nii.gz'))
# Saving original (undeformed image)
# CMR: format: Patient0001_Slice0001_ORG_NA_GT.nii.gz
# Lung CT: ...
nib.save(nifti_img, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e])+'_GT.nii.gz'))
noise_step = hyp_parameters["start_noise_step"]
with torch.no_grad():
for im in range(1):
# # Permute
# if hyp_parameters["ndims"] == 2:
# [img, mask] = utils.random_permute([img, mask], select_dims=[-1, -2]) # add random rotation to image
# elif hyp_parameters["ndims"] == 3:
# [img, mask] = utils.random_permute([img, mask], select_dims=[-1, -2, -3]) # add random rotation to image
print('Generating - >', 'Subject-',pid,', Scan-',e,' (',im,'/',hyp_parameters["aug_coe"],')', end='\r')
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save] = Deformddpm.diff_recover(img_org=img,cond_imgs=target_img.clone().detach(),msk_org=mask,T=[None,hyp_parameters["timesteps"]],v_scale=hyp_parameters["v_scale"],t_save=None,proc_type=hyp_parameters["condition_type"])
denoise_imgs = img_rec.cpu().detach().numpy()
denoise_msks = msk_rec.cpu().detach().numpy()
noisy_imgs_np = img_diff.cpu().detach().numpy()
noisy_msks_np = msk_diff.cpu().detach().numpy()
nifti_img_aug = utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"])
nifti_mask_aug = utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"])
nifti_img = utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"])
nifti_mask = utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"])
nib.save(nifti_img_aug, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
nib.save(nifti_mask_aug, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))
# Saving noisy image to nifti
# CMR: format: Patient0001_Slice0001_NosieImg0001_NoiseStep0070.nii.gz
# Lung CT: ...
nib.save(nifti_img, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'.nii.gz'))
nib.save(nifti_mask, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'_GT.nii.gz'))
if (im - hyp_parameters["start_noise_step"])%2 == 0:
noise_step = noise_step + hyp_parameters["noise_step"]
# break # for testing
if e > 5:
break