| """
|
| OM_reg_flexres_om.py — Full-resolution registration using OMorpher.
|
|
|
| Drop-in replacement for OM_reg_flexres.py. Produces identical outputs but
|
| uses OMorpher instead of DeformDDPM + STN + standalone apply_ddf().
|
|
|
| Usage:
|
| python Scripts/OM_reg_flexres_om.py -C Config/config_om.yaml
|
| """
|
|
|
| import os
|
| import sys
|
| import argparse
|
|
|
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
| import nibabel as nib
|
| import yaml
|
| import SimpleITK as sitk
|
| from tqdm import tqdm
|
|
|
| import utils
|
| from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order
|
| from OMorpher import OMorpher
|
|
|
|
|
|
|
# Command-line interface: a single optional path to the YAML run configuration.
_CONFIG_FLAGS = ("--config", "-C")
_CONFIG_OPTIONS = {
    "help": "Path for the config file",
    "type": str,
    "default": "Config/config_om.yaml",
    "required": False,
}

parser = argparse.ArgumentParser()
parser.add_argument(*_CONFIG_FLAGS, **_CONFIG_OPTIONS)
args = parser.parse_args()
|
|
|
|
|
|
|
# Load the run configuration. safe_load only builds plain Python objects, so the
# YAML file cannot trigger arbitrary object construction.
with open(args.config, "r") as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Augmentation output folders. makedirs(exist_ok=True) replaces the former
# "exists? then create" check, which could race with another process creating
# the same folder; it is also how the registration folders are created.
for _aug_dir_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_aug_dir_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# This script forces a batch size of 1 regardless of the config file.
hyp_parameters["batchsize"] = 1
model_img_sz = hyp_parameters["img_size"]
|
|
|
|
|
|
|
# Inference data: the "Brats2019" database, restricted to the "brain" label key.
label_keys = ["brain"]
database = ["Brats2019"]

dataset = OminiDataset_inference_w_all(
    database=database,
    label_key=label_keys,
    min_crop_ratio=1.0,  # presumably disables random cropping — confirm in Dataloader
    transform=None,
)
|
|
|
|
|
|
|
# Checkpoint location:
#   Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
_model_id, _data_name, _net_name = (
    hyp_parameters[k] for k in ("model_id_str", "data_name", "net_name")
)
epoch = f"{_model_id}_{_data_name}_{_net_name}"
model_save_path = os.path.join(f"Models/{_data_name}_{_net_name}/", f"{epoch}.pth")
print("Loading model from:", model_save_path)

# The device comes from the config; "cpu" is used when the key is absent.
_om_device = str(hyp_parameters.get("device", "cpu"))
om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=_om_device,
)
print(om)
|
|
|
|
|
|
|
# Full-resolution outputs go to sibling folders named "<reg folder>_fullres/".
_REG_DIR_KEYS = ("reg_img_savepath", "reg_msk_savepath", "reg_ddf_savepath")
(
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
) = (hyp_parameters[k].rstrip("/") + "_fullres/" for k in _REG_DIR_KEYS)

# Create every registration output folder (model-resolution first, then full-res).
_reg_out_dirs = [hyp_parameters[k] for k in _REG_DIR_KEYS]
_reg_out_dirs += [reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres]
for _out_dir in _reg_out_dirs:
    os.makedirs(_out_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
def center_pad_to_cube(volume):
    """Zero-pad the first three axes of ``volume`` to the length of the longest one.

    Each axis is padded as symmetrically as possible; when the gap is odd the
    extra voxel goes at the end. Axes beyond the third are not padded.

    Args:
        volume: NumPy array whose first three axes are padded.

    Returns:
        The padded array (``np.pad`` keeps the input dtype).
    """
    spatial = volume.shape[:3]
    side = max(spatial)
    widths = [(gap // 2, gap - gap // 2) for gap in (side - length for length in spatial)]
    # Trailing (e.g. channel) axes are left untouched.
    widths.extend([(0, 0)] * (volume.ndim - 3))
    return np.pad(volume, widths, mode="constant", constant_values=0)
|
|
|
|
|
def load_fullres_volume(key, ds):
    """Read the image at ``key`` at native resolution and prepare it for OMorpher.

    Steps:
    1. Read with SimpleITK and reorder axes with ``reverse_axis_order``.
    2. For 4-D data, keep a single channel.
    3. Clip CT intensities to ``ds.clamp_range``.
    4. Normalise with ``ds.normalize``.
    5. Zero-pad to a cube.

    Args:
        key: image file path; also the key into ``ds.ALLdata_filtered``.
        ds: dataset providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.

    Returns:
        The cube-padded NumPy volume.
    """
    arr = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    if arr.ndim == 4:
        # Keep the first configured channel, or channel 0 when none is configured.
        wanted = ds.get_channel_ids(key)
        arr = arr[:, :, :, wanted[0] if len(wanted) > 0 else 0]

    # Clipping applies only when a range is configured AND the scan is CT.
    if ds.clamp_range is not None and ds.ALLdata_filtered[key].get("Modality", None) == "CT":
        low, high = ds.clamp_range[0], ds.clamp_range[1]
        arr = np.clip(arr, low, high)

    return center_pad_to_cube(ds.normalize(arr))
|
|
|
|
|
def load_fullres_label(key, ds, label_key):
    """Read segmentation ``label_key`` for image ``key`` at native resolution.

    The label gets the same axis reordering and cube padding as the image, but
    no intensity normalisation. For labels with more than three axes, the last
    axis is indexed with the whole list from ``ds.get_channel_ids(key)``, so the
    channel axis is kept rather than reduced to a single channel.

    Args:
        key: image key into ``ds.ALLdata_filtered``.
        ds: dataset providing ``ALLdata_filtered`` and ``get_channel_ids``.
        label_key: entry name under ``Label_path["segmentation"]``.

    Returns:
        The cube-padded NumPy label, or ``None`` when the subject has no such entry.
    """
    segmentations = ds.ALLdata_filtered[key].get("Label_path", {}).get("segmentation", {})
    if label_key not in segmentations:
        return None

    lbl = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(segmentations[label_key])))
    if lbl.ndim > 3:
        ids = ds.get_channel_ids(key)
        if len(ids) > 0:
            lbl = lbl[..., ids]
    return center_pad_to_cube(lbl)
|
|
|
|
|
|
|
|
|
keys = list(dataset.ALLdata_filtered.keys())
print("total num of images:", len(keys))
device = om.device


def _to_numpy(data):
    """Return ``data`` as a NumPy array; torch tensors are detached and copied to the CPU.

    Every volume written by this script goes through here. Full-resolution
    tensors, which may live on the model's device, are therefore converted the
    same way as model-resolution ones instead of reaching
    ``utils.converet_to_nibabel`` as raw (possibly GPU) tensors.
    """
    if torch.is_tensor(data):
        return data.detach().cpu().numpy()
    return data


def _save_nifti(data, out_dir, barcode_parts, suffix=".nii.gz"):
    """Write ``data`` to ``<out_dir>/<utils.get_barcode(barcode_parts)><suffix>``.

    Args:
        data: torch tensor or NumPy array to save.
        out_dir: existing output directory.
        barcode_parts: list passed to ``utils.get_barcode`` to build the file stem.
        suffix: ``".nii.gz"`` for images/DDFs, ``"_GT.nii.gz"`` for masks.
    """
    nib.save(
        utils.converet_to_nibabel(_to_numpy(data), ndims=hyp_parameters["ndims"]),
        os.path.join(out_dir, utils.get_barcode(barcode_parts) + suffix),
    )


for e, key in enumerate(tqdm(keys)):
    pid = e
    print(f"Processing patient {pid}, image {e}, key: {key}")

    # ---- Moving image: load at native resolution; OMorpher derives the model input.
    fullres_vol = load_fullres_volume(key, dataset)
    om.set_init_img(fullres_vol)
    img = om._init_img  # model input (presumably resampled to the model grid)
    fullres_img_tensor = om._init_img_raw  # full-resolution, cube-padded tensor
    # assumes an (N, C, *spatial) layout, so shape[2:] is the spatial size
    orig_sz = list(fullres_img_tensor.shape[2:])
    print(f" Full-res padded shape: {orig_sz}")

    # ---- Labels at model and full resolution, concatenated along dim 1.
    masks_model = []
    masks_fullres = []
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        if lab is None:
            # load_fullres_label returns None when the subject lacks this label;
            # skip it instead of forwarding None to OMorpher and torch.cat.
            continue
        model_t, fullres_t = om._standardize_label(lab)
        masks_model.append(model_t)
        masks_fullres.append(fullres_t)

    if masks_model:
        mask = torch.cat(masks_model, dim=1)
        fullres_msk_tensor = torch.cat(masks_fullres, dim=1)
    else:
        mask = None
        fullres_msk_tensor = None

    # The first subject's model-resolution image is the condition image for every
    # subject (including itself), see om.set_cond_img below.
    if e == 0:
        target_img = img.clone().detach()

    # ---- Save the unregistered inputs.
    subject_code = [pid, e]
    _save_nifti(img, hyp_parameters["reg_img_savepath"], subject_code)
    if mask is not None:
        _save_nifti(mask, hyp_parameters["reg_msk_savepath"], subject_code, "_GT.nii.gz")
    _save_nifti(fullres_img_tensor, reg_img_savepath_fullres, subject_code)
    if fullres_msk_tensor is not None:
        _save_nifti(fullres_msk_tensor, reg_msk_savepath_fullres, subject_code, "_GT.nii.gz")

    # ---- Register to the target and save the warped outputs.
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print(
                f" Generating -> Subject-{pid}, Scan-{e} "
                f'({im}/{hyp_parameters["aug_coe"]})',
                end="\r",
            )

            om.set_init_img(img)
            om.set_cond_img(target_img.clone().detach())
            om.predict(
                T=[None, hyp_parameters["timesteps"]],
                proc_type=hyp_parameters["condition_type"],
            )
            ddf_comp = om.get_def()
            sample_code = [pid, e, im, noise_step]

            # Model-resolution warps.
            img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
            _save_nifti(img_rec, hyp_parameters["reg_img_savepath"], sample_code)
            if mask is not None:
                msk_rec = om.apply_def(
                    img=mask, ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                _save_nifti(
                    msk_rec, hyp_parameters["reg_msk_savepath"], sample_code, "_GT.nii.gz"
                )

            # Full-resolution warps with the same DDF.
            # NOTE(review): the full-res image uses padding_mode="border" while the
            # model-res image uses "zeros" — confirm this difference is intended.
            img_rec_fullres = om.apply_def(
                img=fullres_img_tensor, ddf=ddf_comp, padding_mode="border",
            )
            if fullres_msk_tensor is not None:
                msk_rec_fullres = om.apply_def(
                    img=fullres_msk_tensor, ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )

            # NOTE(review): only the sampling grid is resized; displacement values
            # are not rescaled. That is correct only if get_def() returns
            # resolution-independent (e.g. normalized) displacements — confirm.
            ddf_fullres = F.interpolate(
                ddf_comp, size=orig_sz, mode="trilinear", align_corners=False,
            )

            _save_nifti(img_rec_fullres, reg_img_savepath_fullres, sample_code)
            if fullres_msk_tensor is not None:
                _save_nifti(
                    msk_rec_fullres, reg_msk_savepath_fullres, sample_code, "_GT.nii.gz"
                )
            _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, sample_code)

            # Runs after this sample's files are written; with range(1) it does not
            # change any saved file name.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # Hard-coded limit: stop after subject index 6 (the first seven subjects).
    if e > 5:
        break
|
|
|