import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.utils import make_grid
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import random
import os
import utils
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *
# NOTE: kept after the wildcard import above on purpose — the import order is
# preserved so that any name the wildcard rebinds is rebound the same way.
from torchvision.utils import save_image
from einops import rearrange, reduce, repeat
import numpy as np
import nibabel as nib
from tqdm import tqdm
import yaml
import argparse
import torch.nn.functional as F
import SimpleITK as sitk
from skimage.transform import resize

# Small numerical epsilon. Note that 10e-8 evaluates to 1e-7.
EPS = 10e-8

# ---- Command-line interface: only the path of the YAML config is configurable.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_om.yaml",
    required=False,
)
args = parser.parse_args()

#=======================================================================================================================
# Load the YAML file into a dictionary
with open(args.config, 'r') as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Create the augmentation output directories. `exist_ok=True` replaces the
# former "check, then create" pattern, which could fail if the directory
# appeared between the check and the creation.
os.makedirs(hyp_parameters["aug_img_savepath"], exist_ok=True)
os.makedirs(hyp_parameters["aug_msk_savepath"], exist_ok=True)
os.makedirs(hyp_parameters["aug_ddf_savepath"], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Inference processes one volume at a time, whatever the config says.
hyp_parameters['batchsize'] = 1
model_img_sz = hyp_parameters['img_size']  # e.g. 128

# =======================================================================================================================
# Dataset is used only for its filtering logic (to get the right set of keys + metadata).
# We bypass the DataLoader and load volumes directly to ensure deterministic center-padding
# that is identical between the 128^3 model input and the full-res volume.
label_keys = ['brain']
database = ['Brats2019']
dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=1.0,
    label_key=label_keys,
    database=database)

# =======================================================================================================================
# Checkpoint location: Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(model_save_path, str(epoch)+'.pth')

Net = get_net(hyp_parameters["net_name"])

Deformddpm = DeformDDPM(
    network=Net(
        n_steps=hyp_parameters["timesteps"],
        ndims=hyp_parameters["ndims"],
        num_input_chn=hyp_parameters["num_input_chn"],
        res=model_img_sz,
    ),
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"]] + [model_img_sz]*hyp_parameters["ndims"],
    device=hyp_parameters["device"],
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
    inf_mode=True,
)
Deformddpm.to(hyp_parameters["device"])

ddf_stn = STN(
    img_sz=model_img_sz,
    ndims=hyp_parameters["ndims"],
    padding_mode=hyp_parameters['padding_mode'],
    device=hyp_parameters["device"],
)
ddf_stn.to(hyp_parameters["device"])

print("Loading model from:", model_save_path)
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted source.
checkpoint = torch.load(model_save_path, map_location='cpu')
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()

# Full-res output directories (append _fullres to the standard paths)
reg_img_savepath_fullres = hyp_parameters['reg_img_savepath'].rstrip('/') + '_fullres/'
reg_msk_savepath_fullres = hyp_parameters['reg_msk_savepath'].rstrip('/') + '_fullres/'
reg_ddf_savepath_fullres = hyp_parameters['reg_ddf_savepath'].rstrip('/') + '_fullres/'
os.makedirs(hyp_parameters['reg_img_savepath'], exist_ok=True)
os.makedirs(hyp_parameters['reg_msk_savepath'], exist_ok=True)
os.makedirs(hyp_parameters['reg_ddf_savepath'], exist_ok=True)
os.makedirs(reg_img_savepath_fullres, exist_ok=True)
os.makedirs(reg_msk_savepath_fullres, exist_ok=True)
os.makedirs(reg_ddf_savepath_fullres, exist_ok=True)


# ========== Helper functions ==========

def center_pad_to_cube(volume):
    """Pad volume to a cube using the max dimension, with symmetric (center) padding.

    Only the first three axes are padded; any further axes (e.g. multi-channel
    labels) are left untouched. When the required padding along an axis is odd,
    the extra voxel goes after the data. Padding value is 0.

    Args:
        volume: NumPy array with at least three axes.

    Returns:
        A zero-padded copy whose first three axes all have the length of the
        largest of them.
    """
    max_dim = max(volume.shape[:3])
    pad_width = []
    for s in volume.shape[:3]:
        total_pad = max_dim - s
        pad_before = total_pad // 2
        pad_after = total_pad - pad_before
        pad_width.append((pad_before, pad_after))
    # Handle extra dims (e.g., multi-channel labels)
    pad_width.extend([(0, 0)] * (volume.ndim - 3))
    return np.pad(volume, pad_width, mode='constant', constant_values=0)


def load_fullres_volume(key, ds):
    """Load original-resolution volume: axis reorder, clamp, normalize, center-pad to cube.

    Args:
        key: Image file path; also the key into ``ds.ALLdata_filtered``.
        ds: Dataset object providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.

    Returns:
        The normalized volume, center-padded to a cube.
    """
    volume = sitk.ReadImage(key)
    volume = sitk.GetArrayFromImage(volume)
    volume = reverse_axis_order(volume)
    if volume.ndim == 4:
        # Multi-channel image: keep the first configured channel, or channel 0
        # when the dataset lists none.
        channel_ids = ds.get_channel_ids(key)
        channel_id = channel_ids[0] if len(channel_ids) > 0 else 0
        volume = volume[:, :, :, channel_id]
    # CT clamping
    if ds.clamp_range is not None:
        modality = ds.ALLdata_filtered[key].get("Modality", None)
        if modality == "CT":
            volume = np.clip(volume, ds.clamp_range[0], ds.clamp_range[1])
    volume = ds.normalize(volume)
    volume = center_pad_to_cube(volume)
    return volume  # shape: [D, D, D] (cubic)


def load_fullres_label(key, ds, label_key):
    """Load original-resolution label: axis reorder, center-pad to cube (no resize).

    Args:
        key: Key of the image in ``ds.ALLdata_filtered``.
        ds: Dataset object providing ``ALLdata_filtered`` and ``get_channel_ids``.
        label_key: Name of the segmentation label to load.

    Returns:
        The center-padded label array, or ``None`` when the entry has no
        segmentation label named ``label_key``.
    """
    label_path_dict = ds.ALLdata_filtered[key].get('Label_path', {})
    task_labels = label_path_dict.get('segmentation', {})
    if label_key not in task_labels:
        return None
    label = sitk.ReadImage(task_labels[label_key])
    label = sitk.GetArrayFromImage(label)
    label = reverse_axis_order(label)
    if label.ndim > 3:
        # Keep only the configured channels; if none are configured the label
        # is left as loaded.
        channel_ids = ds.get_channel_ids(key)
        if len(channel_ids) != 0:
            label = label[..., channel_ids]
    label = center_pad_to_cube(label)
    return label


def apply_ddf(volume_tensor, ddf, padding_mode='border', resample_mode='bilinear'):
    """Apply DDF to volume tensor at any resolution.

    The DDF stores fractional displacements (value * max_sz = voxel displacement).
    When the DDF is spatially upscaled via trilinear interpolation from model
    resolution to full resolution, the fractional values remain correct — we use
    the new spatial size as max_sz, which correctly scales the voxel displacement
    proportionally.

    Args:
        volume_tensor: Tensor whose axes from index 2 onwards are the three
            spatial axes.
        ddf: Displacement field with 3 channels on axis 1 and the same spatial
            size as ``volume_tensor``; all grid tensors are created on its device.
        padding_mode: ``padding_mode`` forwarded to ``F.grid_sample``.
        resample_mode: ``mode`` forwarded to ``F.grid_sample``.

    Returns:
        ``volume_tensor`` resampled at the displaced grid positions.
    """
    device = ddf.device
    ndims = 3
    img_sz = list(volume_tensor.shape[2:])
    # Per-axis spatial size, broadcastable against the channel axis of the DDF.
    max_sz = torch.reshape(
        torch.tensor(img_sz, dtype=torch.float32, device=device),
        [1, ndims] + [1] * ndims)
    # Identity sampling grid in voxel coordinates.
    ref_grid = torch.reshape(
        torch.stack(torch.meshgrid(
            [torch.arange(s, device=device) for s in img_sz],
            indexing='ij'), 0),
        [1, ndims] + img_sz)
    # Half-extent per axis, used to map voxel coordinates into [-1, 1].
    img_shape = torch.reshape(
        torch.tensor([(s - 1) / 2. for s in img_sz], dtype=torch.float32, device=device),
        [1] + [1] * ndims + [ndims])
    # Move channels last, normalize to [-1, 1], then reverse the coordinate
    # order for grid_sample.
    grid = torch.flip(
        (ddf * max_sz + ref_grid).permute(
            [0] + list(range(2, 2 + ndims)) + [1]) / img_shape - 1,
        dims=[-1])
    return F.grid_sample(volume_tensor, grid.float(),
                         mode=resample_mode, padding_mode=padding_mode,
                         align_corners=True)


# ========== Main inference loop ==========

keys = list(dataset.ALLdata_filtered.keys())
print("total num of images:", len(keys))

for e, key in enumerate(tqdm(keys)):
    pid = e
    print(f'Processing patient {pid}, image {e}, key: {key}')

    # --- Load full-resolution volume (center-padded to cube) ---
    fullres_vol = load_fullres_volume(key, dataset)
    orig_sz = list(fullres_vol.shape)  # e.g. [240, 240, 240]
    print(f" Full-res padded shape: {orig_sz}")

    # --- Resize to model resolution for inference ---
    vol_model = resize(fullres_vol, [model_img_sz] * 3, anti_aliasing=True, preserve_range=True)
    img = torch.tensor(vol_model[None, None, :, :, :], dtype=torch.float32, device=hyp_parameters["device"])

    # --- Load full-res labels and resize to model resolution ---
    fullres_labels = {}
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        if lab is not None:
            fullres_labels[lk] = lab

    # Build mask at model resolution (128^3)
    label_arrays_model = []
    label_arrays_fullres = []
    for lk in label_keys:
        if lk in fullres_labels:
            lab = fullres_labels[lk]
            # order=0 (nearest neighbour) keeps label values discrete.
            lab_model = resize(lab, [model_img_sz] * 3, anti_aliasing=False, preserve_range=True, order=0)
            # Bring both resolutions to channel-first layout.
            if lab_model.ndim == 3:
                lab_model = lab_model[None, :, :, :]
            elif lab_model.ndim > 3:
                lab_model = np.transpose(lab_model, (3, 0, 1, 2))
            label_arrays_model.append(lab_model)
            if lab.ndim == 3:
                lab = lab[None, :, :, :]
            elif lab.ndim > 3:
                lab = np.transpose(lab, (3, 0, 1, 2))
            label_arrays_fullres.append(lab)
        else:
            # Missing label: single-channel placeholder filled with -1.
            label_arrays_model.append(np.full([1] + [model_img_sz] * 3, -1))
            label_arrays_fullres.append(np.full([1] + orig_sz, -1))

    if len(label_arrays_model) > 0:
        mask_model_np = np.concatenate(label_arrays_model, axis=0)
        mask = torch.tensor(mask_model_np[None], dtype=torch.float32, device=hyp_parameters["device"])
        fullres_msk_np = np.concatenate(label_arrays_fullres, axis=0)
        fullres_msk_tensor = torch.tensor(fullres_msk_np[None], dtype=torch.float32, device=hyp_parameters["device"])
    else:
        mask = None
        fullres_msk_np = None
        fullres_msk_tensor = None

    # Build full-res image tensor
    fullres_img_tensor = torch.tensor(fullres_vol[None, None, :, :, :], dtype=torch.float32, device=hyp_parameters["device"])

    # --- Save target conditioning image (first subject) ---
    if e <= 0:
        target_img = img.clone().detach()

    # --- Save original images at 128^3 ---
    image_original = img.cpu().numpy()
    nib.save(utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"]),
             os.path.join(hyp_parameters['reg_img_savepath'],
                          utils.get_barcode([pid, e]) + '.nii.gz'))
    if mask is not None:
        mask_original = mask.cpu().numpy()
        nib.save(utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"]),
                 os.path.join(hyp_parameters['reg_msk_savepath'],
                              utils.get_barcode([pid, e]) + '_GT.nii.gz'))

    # --- Save original at full-res ---
    # fullres_vol is [D,D,D], wrap as [1,1,D,D,D] for converet_to_nibabel
    nib.save(utils.converet_to_nibabel(fullres_vol[None, None], ndims=hyp_parameters["ndims"]),
             os.path.join(reg_img_savepath_fullres,
                          utils.get_barcode([pid, e]) + '.nii.gz'))
    if fullres_msk_np is not None:
        # fullres_msk_np is [C,D,D,D], wrap as [1,C,D,D,D]
        nib.save(utils.converet_to_nibabel(fullres_msk_np[None], ndims=hyp_parameters["ndims"]),
                 os.path.join(reg_msk_savepath_fullres,
                              utils.get_barcode([pid, e]) + '_GT.nii.gz'))

    # --- Diffusion recovery at model resolution ---
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print(f' Generating -> Subject-{pid}, Scan-{e} ({im}/{hyp_parameters["aug_coe"]})', end='\r')

            [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = \
                Deformddpm.diff_recover(
                    img_org=img,
                    cond_imgs=target_img.clone().detach(),
                    msk_org=mask,
                    T=[None, hyp_parameters["timesteps"]],
                    v_scale=hyp_parameters["v_scale"],
                    t_save=None,
                    proc_type=hyp_parameters["condition_type"])

            # --- Save 128^3 results (same as OM_reg.py) ---
            denoise_imgs = img_rec.cpu().numpy()
            noisy_imgs_np = img_diff.cpu().numpy()
            nib.save(utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"]),
                     os.path.join(hyp_parameters['reg_img_savepath'],
                                  utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
            nib.save(utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"]),
                     os.path.join(hyp_parameters['reg_img_savepath'],
                                  utils.get_barcode([pid, e, im, noise_step],
                                                    header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '.nii.gz'))
            if msk_rec is not None:
                denoise_msks = msk_rec.cpu().numpy()
                noisy_msks_np = msk_diff.cpu().numpy()
                nib.save(utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"]),
                         os.path.join(hyp_parameters['reg_msk_savepath'],
                                      utils.get_barcode([pid, e, im, noise_step]) + '_GT.nii.gz'))
                nib.save(utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"]),
                         os.path.join(hyp_parameters['reg_msk_savepath'],
                                      utils.get_barcode([pid, e, im, noise_step],
                                                        header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '_GT.nii.gz'))

            # --- Upscale DDFs to original resolution ---
            # NOTE(review): the DDF is upsampled with align_corners=False while
            # apply_ddf samples with align_corners=True — confirm this is intended.
            ddf_fullres = F.interpolate(ddf_comp, size=orig_sz, mode='trilinear', align_corners=False)
            ddf_rand_fullres = F.interpolate(ddf_rand, size=orig_sz, mode='trilinear', align_corners=False)

            # --- Apply DDFs at original resolution ---
            img_rec_fullres = apply_ddf(fullres_img_tensor, ddf_fullres, padding_mode='border')
            img_noisy_fullres = apply_ddf(fullres_img_tensor, ddf_rand_fullres, padding_mode='border')
            if fullres_msk_tensor is not None:
                msk_rec_fullres = apply_ddf(fullres_msk_tensor, ddf_fullres,
                                            padding_mode='zeros', resample_mode='nearest')
                msk_noisy_fullres = apply_ddf(fullres_msk_tensor, ddf_rand_fullres,
                                              padding_mode='zeros', resample_mode='nearest')

            # --- Save full-res results ---
            # Move results to host NumPy arrays before saving, exactly like the
            # model-resolution saves above; the tensors live on the configured
            # device, which may be a GPU.
            nib.save(utils.converet_to_nibabel(img_rec_fullres.cpu().numpy(), ndims=hyp_parameters["ndims"]),
                     os.path.join(reg_img_savepath_fullres,
                                  utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
            nib.save(utils.converet_to_nibabel(img_noisy_fullres.cpu().numpy(), ndims=hyp_parameters["ndims"]),
                     os.path.join(reg_img_savepath_fullres,
                                  utils.get_barcode([pid, e, im, noise_step],
                                                    header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '.nii.gz'))
            if fullres_msk_tensor is not None:
                nib.save(utils.converet_to_nibabel(msk_rec_fullres.cpu().numpy(), ndims=hyp_parameters["ndims"]),
                         os.path.join(reg_msk_savepath_fullres,
                                      utils.get_barcode([pid, e, im, noise_step]) + '_GT.nii.gz'))
                nib.save(utils.converet_to_nibabel(msk_noisy_fullres.cpu().numpy(), ndims=hyp_parameters["ndims"]),
                         os.path.join(reg_msk_savepath_fullres,
                                      utils.get_barcode([pid, e, im, noise_step],
                                                        header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '_GT.nii.gz'))

            # Save full-res DDF (converet_to_nibabel handles multi-channel → channel-last)
            nib.save(utils.converet_to_nibabel(ddf_fullres.cpu().numpy(), ndims=hyp_parameters["ndims"]),
                     os.path.join(reg_ddf_savepath_fullres,
                                  utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))

            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # Hard-coded cap: stop once the subject with index 6 has been processed.
    if e > 5:
        break