| import torch
|
| import torchvision
|
| from torch import nn
|
| from torchvision.utils import save_image
|
| from torch.utils.data import DataLoader
|
| from torch.optim import Adam
|
| from torchvision.utils import make_grid
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN
|
| from torchvision.transforms import Lambda
|
| import random
|
| import os
|
| import utils
|
| from Dataloader.dataloader0 import get_dataloader
|
| from Dataloader.dataLoader import *
|
|
|
| from torchvision.utils import save_image
|
| from einops import rearrange, reduce, repeat
|
| import numpy as np
|
| import nibabel as nib
|
| from tqdm import tqdm
|
| import yaml
|
| import argparse
|
| import torch.nn.functional as F
|
| import SimpleITK as sitk
|
| from skimage.transform import resize
|
|
|
# NOTE(review): 10e-8 equals 1e-7 (not 1e-8); EPS is not referenced in the
# visible code.
EPS = 10e-8

# Command-line interface: one optional argument pointing at the YAML config.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    required=False,
    default="Config/config_om.yaml",
    help="Path for the config file",
)
args = parser.parse_args()
|
|
|
|
|
|
|
# Load the hyper-parameter file. yaml.safe_load only constructs plain Python
# objects, so a config file cannot instantiate arbitrary classes.
with open(args.config, 'r') as file:
    hyp_parameters = yaml.safe_load(file)
# An empty or non-mapping YAML document would otherwise surface much later as
# an obscure TypeError on the first hyp_parameters[...] lookup.
if not isinstance(hyp_parameters, dict):
    raise ValueError(
        f"Config file {args.config!r} must contain a YAML mapping, "
        f"got {type(hyp_parameters).__name__}")
print(hyp_parameters)

# Create the augmentation output directories. exist_ok=True replaces the
# check-then-create pattern (racy if another process creates the directory in
# between) and matches how the reg_* directories are created below.
# NOTE(review): the aug_* directories are not written to by the visible code.
for _aug_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_aug_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])
|
|
|
# Inference processes one volume at a time, regardless of the config value.
hyp_parameters['batchsize'] = 1
# Side length used for the model-resolution grid (see [model_img_sz] * 3 below).
model_img_sz = hyp_parameters['img_size']

# Label channels and source databases to pull from the dataset index.
label_keys = ['brain']
database = ['Brats2019']

dataset = OminiDataset_inference_w_all(
    database=database,
    label_key=label_keys,
    min_crop_ratio=1.0,
    transform=None,
)
|
|
|
|
|
|
|
# Checkpoint location:
#   Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
_model_id = hyp_parameters["model_id_str"]
_data_name = hyp_parameters["data_name"]
_net_name = hyp_parameters["net_name"]
epoch = f'{_model_id}_{_data_name}_{_net_name}'
model_save_path = os.path.join(f'Models/{_data_name}_{_net_name}/', f'{epoch}.pth')

# Network class selected by name from the config.
Net = get_net(_net_name)

# Instantiate the backbone first, then wrap it in the diffusion model
# (inf_mode=True).
_network = Net(
    n_steps=hyp_parameters["timesteps"],
    ndims=hyp_parameters["ndims"],
    num_input_chn=hyp_parameters["num_input_chn"],
    res=model_img_sz,
)
Deformddpm = DeformDDPM(
    network=_network,
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"],
               *([model_img_sz] * hyp_parameters["ndims"])],
    device=hyp_parameters["device"],
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
    inf_mode=True,
)
Deformddpm.to(hyp_parameters["device"])

# Spatial transformer at model resolution.
# NOTE(review): ddf_stn is not used by the visible code.
ddf_stn = STN(
    img_sz=model_img_sz,
    ndims=hyp_parameters["ndims"],
    padding_mode=hyp_parameters['padding_mode'],
    device=hyp_parameters["device"],
)
ddf_stn.to(hyp_parameters["device"])

# Load the trained weights onto the CPU first; load_state_dict copies them into
# the model's parameters, which were moved to the target device above.
print("Loading model from:", model_save_path)
checkpoint = torch.load(model_save_path, map_location='cpu')
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()
|
|
|
|
|
# Model-resolution output directories from the config, and sibling
# "<dir>_fullres/" directories for the full-resolution outputs.
_reg_keys = ('reg_img_savepath', 'reg_msk_savepath', 'reg_ddf_savepath')
_reg_dirs = [hyp_parameters[k] for k in _reg_keys]
reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres = (
    d.rstrip('/') + '_fullres/' for d in _reg_dirs)

for _out_dir in _reg_dirs + [reg_img_savepath_fullres,
                             reg_msk_savepath_fullres,
                             reg_ddf_savepath_fullres]:
    os.makedirs(_out_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
def center_pad_to_cube(volume):
    """Zero-pad the first three axes of ``volume`` to the length of the longest one.

    Padding is split as evenly as possible between the two ends of each axis
    (the extra voxel, if any, goes at the end). Any axes beyond the third are
    left untouched.
    """
    spatial = volume.shape[:3]
    side = max(spatial)
    widths = [((side - n) // 2, side - n - (side - n) // 2) for n in spatial]
    # Trailing (e.g. channel) axes get no padding; a negative count yields [].
    widths += [(0, 0)] * (volume.ndim - 3)
    return np.pad(volume, widths, mode='constant', constant_values=0)
|
|
|
|
|
def load_fullres_volume(key, ds):
    """Load the image at path ``key`` at native resolution, normalised and cube-padded.

    Steps:
      1. Read the file with SimpleITK and reorder its axes with ``reverse_axis_order``.
      2. For 4-D arrays, keep one slice of the last axis: the first id from
         ``ds.get_channel_ids(key)``, or 0 when that list is empty.
      3. Clip to ``ds.clamp_range`` when it is set and the record's "Modality"
         entry is "CT".
      4. Apply ``ds.normalize``.
      5. Zero-pad to a cube with ``center_pad_to_cube``.

    No resizing is done here.
    """
    arr = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    if arr.ndim == 4:
        ids = ds.get_channel_ids(key)
        arr = arr[:, :, :, ids[0] if len(ids) > 0 else 0]

    # Intensity clamping only applies to CT scans.
    if (ds.clamp_range is not None
            and ds.ALLdata_filtered[key].get("Modality", None) == "CT"):
        arr = np.clip(arr, ds.clamp_range[0], ds.clamp_range[1])

    return center_pad_to_cube(ds.normalize(arr))
|
|
|
|
|
def load_fullres_label(key, ds, label_key):
    """Load segmentation ``label_key`` for image ``key`` at native resolution, cube-padded.

    The label path is looked up in
    ``ds.ALLdata_filtered[key]['Label_path']['segmentation'][label_key]``.
    Returns None when no such entry exists. Arrays with more than three axes
    are indexed on their last axis with ``ds.get_channel_ids(key)`` (when that
    list is non-empty). No resizing is done.
    """
    seg_paths = (ds.ALLdata_filtered[key]
                 .get('Label_path', {})
                 .get('segmentation', {}))
    if label_key not in seg_paths:
        return None

    arr = reverse_axis_order(
        sitk.GetArrayFromImage(sitk.ReadImage(seg_paths[label_key])))

    if arr.ndim > 3:
        ids = ds.get_channel_ids(key)
        if len(ids) != 0:
            arr = arr[..., ids]

    return center_pad_to_cube(arr)
|
|
|
|
|
def apply_ddf(volume_tensor, ddf, padding_mode='border', resample_mode='bilinear'):
    """Warp ``volume_tensor`` with a dense displacement field (DDF) at any resolution.

    DDF values are treated as *fractional* displacements: along spatial axis
    ``d`` the voxel displacement is ``ddf[:, d] * size_d``, where ``size_d`` is
    that axis' length in ``volume_tensor``. A DDF upsampled from model
    resolution to full resolution can therefore be applied directly; its values
    are scaled by the new spatial size.

    Args:
        volume_tensor: tensor of shape (N, C, *spatial) with 2 or 3 spatial axes.
        ddf: tensor of shape (N, ndims, *spatial) with the same spatial size as
            ``volume_tensor``; channel ``d`` displaces spatial axis ``d``.
        padding_mode: ``F.grid_sample`` padding mode for out-of-range samples.
        resample_mode: ``F.grid_sample`` mode ('bilinear' interpolates
            trilinearly for 3-D volumes; use 'nearest' for label maps).

    Returns:
        The warped tensor, with the same shape as ``volume_tensor``.

    Raises:
        ValueError: if ``ddf`` does not have ``ndims`` channels and the same
            spatial size as ``volume_tensor``.
    """
    device = ddf.device
    img_sz = list(volume_tensor.shape[2:])
    # Derive dimensionality from the input instead of hard-coding 3, so 2-D
    # images work with the same code path.
    ndims = len(img_sz)
    if ddf.dim() != ndims + 2 or ddf.shape[1] != ndims or list(ddf.shape[2:]) != img_sz:
        raise ValueError(
            f"ddf must have shape (N, {ndims}, {', '.join(map(str, img_sz))}) "
            f"to match volume_tensor; got {tuple(ddf.shape)}")

    # Per-axis sizes, broadcastable against (N, ndims, *spatial).
    max_sz = torch.reshape(
        torch.tensor(img_sz, dtype=torch.float32, device=device),
        [1, ndims] + [1] * ndims)
    # Identity grid of voxel indices, shape (1, ndims, *spatial), 'ij' order.
    ref_grid = torch.stack(torch.meshgrid(
        [torch.arange(s, device=device) for s in img_sz], indexing='ij'), 0).unsqueeze(0)
    # With align_corners=True, grid_sample maps -1/+1 to the first/last voxel,
    # so index i normalises to i / ((s - 1) / 2) - 1. A singleton axis would
    # divide by zero; clamping (s - 1) to 1 keeps the grid finite, and
    # grid_sample resolves any coordinate on a size-1 axis to index 0.
    half_extent = torch.reshape(
        torch.tensor([max(s - 1, 1) / 2. for s in img_sz], dtype=torch.float32, device=device),
        [1] + [1] * ndims + [ndims])

    # Absolute sampling positions in voxels, moved to channels-last layout.
    coords = (ddf * max_sz + ref_grid).permute([0] + list(range(2, 2 + ndims)) + [1])
    # grid_sample expects the last dimension ordered (x, y[, z]), i.e. the
    # spatial axes reversed relative to tensor order.
    grid = torch.flip(coords / half_extent - 1, dims=[-1])
    return F.grid_sample(volume_tensor, grid.float(), mode=resample_mode,
                         padding_mode=padding_mode, align_corners=True)
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Inference loop. For every indexed image:
#   * build model-resolution and full-resolution image/label arrays,
#   * run Deformddpm.diff_recover with the first image as cond_imgs,
#   * save the model-resolution outputs, then upsample the predicted DDFs
#     and apply them to the full-resolution image/labels and save those too.
# ---------------------------------------------------------------------------

# Barcode header used for the "noisy" (randomly deformed) outputs.
_NOISY_BARCODE_HEADER = ['Patient', 'Slice', 'NoiseImg', 'NoiseStep']


def _to_numpy(data):
    """Return ``data`` as a NumPy array, moving torch tensors to the CPU first."""
    if torch.is_tensor(data):
        return data.detach().cpu().numpy()
    return data


def _save_nifti(data, directory, filename):
    """Save ``data`` (NumPy array or torch tensor) as NIfTI at ``directory/filename``.

    Every save goes through _to_numpy so that GPU tensors (e.g. the
    full-resolution warps) are converted exactly like the model-resolution
    outputs, which the original code converted with ``.cpu().numpy()``.
    """
    nib.save(utils.converet_to_nibabel(_to_numpy(data), ndims=hyp_parameters["ndims"]),
             os.path.join(directory, filename))


def _channel_first(label):
    """Give ``label`` a leading channel axis.

    3-D arrays get a singleton channel; 4-D arrays have their last axis moved
    to the front. Other ranks are returned unchanged.
    """
    if label.ndim == 3:
        return label[None, :, :, :]
    if label.ndim > 3:
        return np.transpose(label, (3, 0, 1, 2))
    return label


keys = list(dataset.ALLdata_filtered.keys())
print("total num of images:", len(keys))

for e, key in enumerate(tqdm(keys)):
    pid = e
    print(f'Processing patient {pid}, image {e}, key: {key}')

    # Full-resolution volume (cube-padded) and its model-resolution copy.
    fullres_vol = load_fullres_volume(key, dataset)
    orig_sz = list(fullres_vol.shape)
    print(f" Full-res padded shape: {orig_sz}")

    vol_model = resize(fullres_vol, [model_img_sz] * 3,
                       anti_aliasing=True, preserve_range=True)
    img = torch.tensor(vol_model[None, None, :, :, :],
                       dtype=torch.float32, device=hyp_parameters["device"])

    # Labels that exist on disk for this image.
    fullres_labels = {}
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        if lab is not None:
            fullres_labels[lk] = lab

    # One channel-first array per label key, at both resolutions. Labels are
    # resized with order=0 (nearest neighbour) so label values are not
    # blended; missing labels are filled with -1.
    label_arrays_model = []
    label_arrays_fullres = []
    for lk in label_keys:
        if lk in fullres_labels:
            lab = fullres_labels[lk]
            lab_model = resize(lab, [model_img_sz] * 3,
                               anti_aliasing=False, preserve_range=True, order=0)
            label_arrays_model.append(_channel_first(lab_model))
            label_arrays_fullres.append(_channel_first(lab))
        else:
            label_arrays_model.append(np.full([1] + [model_img_sz] * 3, -1))
            label_arrays_fullres.append(np.full([1] + orig_sz, -1))

    if label_arrays_model:
        mask_model_np = np.concatenate(label_arrays_model, axis=0)
        mask = torch.tensor(mask_model_np[None], dtype=torch.float32,
                            device=hyp_parameters["device"])
        fullres_msk_np = np.concatenate(label_arrays_fullres, axis=0)
        fullres_msk_tensor = torch.tensor(fullres_msk_np[None], dtype=torch.float32,
                                          device=hyp_parameters["device"])
    else:
        mask = None
        fullres_msk_np = None
        fullres_msk_tensor = None

    fullres_img_tensor = torch.tensor(fullres_vol[None, None, :, :, :],
                                      dtype=torch.float32,
                                      device=hyp_parameters["device"])

    # The first image's model-resolution tensor is kept and passed as
    # cond_imgs for every subsequent image.
    if e <= 0:
        target_img = img.clone().detach()

    # Save the un-deformed inputs at both resolutions.
    base_code = utils.get_barcode([pid, e])
    _save_nifti(img, hyp_parameters['reg_img_savepath'], base_code + '.nii.gz')
    if mask is not None:
        _save_nifti(mask, hyp_parameters['reg_msk_savepath'], base_code + '_GT.nii.gz')
    _save_nifti(fullres_vol[None, None], reg_img_savepath_fullres, base_code + '.nii.gz')
    if fullres_msk_np is not None:
        _save_nifti(fullres_msk_np[None], reg_msk_savepath_fullres, base_code + '_GT.nii.gz')

    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print(f' Generating -> Subject-{pid}, Scan-{e} ({im}/{hyp_parameters["aug_coe"]})', end='\r')

            [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = \
                Deformddpm.diff_recover(
                    img_org=img,
                    cond_imgs=target_img.clone().detach(),
                    msk_org=mask,
                    T=[None, hyp_parameters["timesteps"]],
                    v_scale=hyp_parameters["v_scale"],
                    t_save=None,
                    proc_type=hyp_parameters["condition_type"])

            step_code = utils.get_barcode([pid, e, im, noise_step])
            noisy_code = utils.get_barcode([pid, e, im, noise_step],
                                           header=_NOISY_BARCODE_HEADER)

            # Model-resolution outputs.
            _save_nifti(img_rec, hyp_parameters['reg_img_savepath'], step_code + '.nii.gz')
            _save_nifti(img_diff, hyp_parameters['reg_img_savepath'], noisy_code + '.nii.gz')
            if msk_rec is not None:
                _save_nifti(msk_rec, hyp_parameters['reg_msk_savepath'], step_code + '_GT.nii.gz')
                _save_nifti(msk_diff, hyp_parameters['reg_msk_savepath'], noisy_code + '_GT.nii.gz')

            # Upsample both DDFs to the padded full-resolution grid;
            # apply_ddf scales their fractional values by the new size.
            ddf_fullres = F.interpolate(ddf_comp, size=orig_sz,
                                        mode='trilinear', align_corners=False)
            ddf_rand_fullres = F.interpolate(ddf_rand, size=orig_sz,
                                             mode='trilinear', align_corners=False)

            img_rec_fullres = apply_ddf(fullres_img_tensor, ddf_fullres,
                                        padding_mode='border')
            img_noisy_fullres = apply_ddf(fullres_img_tensor, ddf_rand_fullres,
                                          padding_mode='border')
            if fullres_msk_tensor is not None:
                # Nearest-neighbour resampling keeps label values discrete.
                msk_rec_fullres = apply_ddf(fullres_msk_tensor, ddf_fullres,
                                            padding_mode='zeros', resample_mode='nearest')
                msk_noisy_fullres = apply_ddf(fullres_msk_tensor, ddf_rand_fullres,
                                              padding_mode='zeros', resample_mode='nearest')

            # Full-resolution outputs. These are torch tensors (possibly on the
            # GPU); _save_nifti converts them to NumPy like every other save.
            _save_nifti(img_rec_fullres, reg_img_savepath_fullres, step_code + '.nii.gz')
            _save_nifti(img_noisy_fullres, reg_img_savepath_fullres, noisy_code + '.nii.gz')
            if fullres_msk_tensor is not None:
                _save_nifti(msk_rec_fullres, reg_msk_savepath_fullres, step_code + '_GT.nii.gz')
                _save_nifti(msk_noisy_fullres, reg_msk_savepath_fullres, noisy_code + '_GT.nii.gz')
            _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, step_code + '.nii.gz')

            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # NOTE(review): stops after index 6, so only the first 7 keys are
    # processed; this looks like a debugging limit — confirm before removing.
    if e > 5:
        break
|
|
|