"""
OM_reg_flexres_om.py — Full-resolution registration using OMorpher.

Drop-in replacement for OM_reg_flexres.py. Produces identical outputs but uses
OMorpher instead of DeformDDPM + STN + standalone apply_ddf().

Usage:
    python Scripts/OM_reg_flexres_om.py -C Config/config_om.yaml

Optional config keys:
    max_reg_subjects: number of subjects to register before stopping
                      (default 7; set to null to process every subject).
"""

import os
import sys
import argparse

# Add project root to path so imports work from Scripts/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
import yaml
import SimpleITK as sitk
from tqdm import tqdm

import utils
from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order
from OMorpher import OMorpher

# ========== CLI ==========
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_om.yaml",
    required=False,
)
args = parser.parse_args()

# ========== Config ==========
with open(args.config, "r") as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# exist_ok=True replaces the racy "exists? then makedirs" pattern.
for _aug_dir in (
    hyp_parameters["aug_img_savepath"],
    hyp_parameters["aug_msk_savepath"],
    hyp_parameters["aug_ddf_savepath"],
):
    os.makedirs(_aug_dir, exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

hyp_parameters["batchsize"] = 1
model_img_sz = hyp_parameters["img_size"]

# Cap on the number of registered subjects. The default of 7 reproduces the
# previously hard-coded `if e > 5: break`; None (YAML null) disables the cap.
max_reg_subjects = hyp_parameters.get("max_reg_subjects", 7)

# ========== Dataset (unchanged — used only for filtering/metadata) ==========
label_keys = ["brain"]
database = ["Brats2019"]
dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=1.0,
    label_key=label_keys,
    database=database,
)

# ========== OMorpher setup ==========
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = os.path.join(
    f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
    str(epoch) + ".pth",
)
print("Loading model from:", model_save_path)

om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)

# ========== Output directories ==========
reg_img_savepath_fullres = hyp_parameters["reg_img_savepath"].rstrip("/") + "_fullres/"
reg_msk_savepath_fullres = hyp_parameters["reg_msk_savepath"].rstrip("/") + "_fullres/"
reg_ddf_savepath_fullres = hyp_parameters["reg_ddf_savepath"].rstrip("/") + "_fullres/"

for p in [
    hyp_parameters["reg_img_savepath"],
    hyp_parameters["reg_msk_savepath"],
    hyp_parameters["reg_ddf_savepath"],
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
]:
    os.makedirs(p, exist_ok=True)


# ========== Helpers ==========
def _to_numpy(data):
    """Return *data* as a NumPy array; torch tensors are detached and copied to CPU."""
    if torch.is_tensor(data):
        return data.detach().cpu().numpy()
    return np.asarray(data)


def _save_nifti(data, path):
    """Convert *data* (tensor or array) to NIfTI with utils.converet_to_nibabel and save it.

    Every output goes through NumPy first. The model-resolution saves always
    passed `.cpu().numpy()` arrays, which is the only input form SOURCE shows
    the helper accepting. The full-resolution saves previously passed raw
    (possibly GPU) tensors.
    """
    nib.save(utils.converet_to_nibabel(_to_numpy(data), ndims=hyp_parameters["ndims"]), path)


def center_pad_to_cube(volume):
    """Pad volume to a cube using the max dimension, with symmetric (center) padding.

    Only the first three axes are padded. Any trailing axes (e.g. channels) get
    (0, 0) padding. When the padding along an axis is odd, the extra voxel is
    added after the data.
    """
    max_dim = max(volume.shape[:3])
    pad_width = []
    for s in volume.shape[:3]:
        total_pad = max_dim - s
        pad_before = total_pad // 2
        pad_width.append((pad_before, total_pad - pad_before))
    pad_width.extend([(0, 0)] * (volume.ndim - 3))
    return np.pad(volume, pad_width, mode="constant", constant_values=0)


def load_fullres_volume(key, ds):
    """Load original-resolution volume: axis reorder, clamp, normalize, center-pad to cube.

    Args:
        key: image path (also the key into ``ds.ALLdata_filtered``).
        ds: dataset providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.
    """
    volume = sitk.ReadImage(key)
    volume = sitk.GetArrayFromImage(volume)
    volume = reverse_axis_order(volume)

    # 4-D input: keep a single channel (the first listed one, else channel 0).
    if volume.ndim == 4:
        channel_ids = ds.get_channel_ids(key)
        channel_id = channel_ids[0] if len(channel_ids) > 0 else 0
        volume = volume[:, :, :, channel_id]

    # Intensity clamping is applied only to CT scans.
    if ds.clamp_range is not None:
        modality = ds.ALLdata_filtered[key].get("Modality", None)
        if modality == "CT":
            volume = np.clip(volume, ds.clamp_range[0], ds.clamp_range[1])

    volume = ds.normalize(volume)
    volume = center_pad_to_cube(volume)
    return volume


def load_fullres_label(key, ds, label_key):
    """Load original-resolution label: axis reorder, center-pad to cube.

    Returns None when the record has no segmentation entry for *label_key*.
    """
    label_path_dict = ds.ALLdata_filtered[key].get("Label_path", {})
    task_labels = label_path_dict.get("segmentation", {})
    if label_key not in task_labels:
        return None

    label = sitk.ReadImage(task_labels[label_key])
    label = sitk.GetArrayFromImage(label)
    label = reverse_axis_order(label)

    if label.ndim > 3:
        channel_ids = ds.get_channel_ids(key)
        if len(channel_ids) != 0:
            label = label[..., channel_ids]

    label = center_pad_to_cube(label)
    return label


# ========== Main inference loop ==========
keys = list(dataset.ALLdata_filtered.keys())
print("total num of images:", len(keys))
device = om.device

for e, key in enumerate(tqdm(keys)):
    pid = e
    print(f"Processing patient {pid}, image {e}, key: {key}")

    # --- Load & standardize volume via OMorpher ---
    fullres_vol = load_fullres_volume(key, dataset)
    om.set_init_img(fullres_vol)
    img = om._init_img  # [1, 1, model_sz, model_sz, model_sz]
    fullres_img_tensor = om._init_img_raw  # [1, 1, D, H, W] full-res tensor
    orig_sz = list(fullres_img_tensor.shape[2:])
    print(f"  Full-res padded shape: {orig_sz}")

    # --- Load & standardize labels via OMorpher ---
    masks_model = []
    masks_fullres = []
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        model_t, fullres_t = om._standardize_label(lab)  # None → -1 placeholder
        masks_model.append(model_t)
        masks_fullres.append(fullres_t)

    if masks_model:
        mask = torch.cat(masks_model, dim=1)  # [1, C_total, S, S, S]
        fullres_msk_tensor = torch.cat(masks_fullres, dim=1)  # [1, C_total, D, H, W]
    else:
        mask = None
        fullres_msk_tensor = None

    # --- Save target conditioning image (first subject) ---
    if e <= 0:
        target_img = img.clone().detach()

    # --- Save original images at model resolution ---
    _save_nifti(
        img,
        os.path.join(hyp_parameters["reg_img_savepath"], utils.get_barcode([pid, e]) + ".nii.gz"),
    )
    if mask is not None:
        _save_nifti(
            mask,
            os.path.join(hyp_parameters["reg_msk_savepath"], utils.get_barcode([pid, e]) + "_GT.nii.gz"),
        )

    # --- Save original at full-res ---
    _save_nifti(
        fullres_img_tensor,
        os.path.join(reg_img_savepath_fullres, utils.get_barcode([pid, e]) + ".nii.gz"),
    )
    if fullres_msk_tensor is not None:
        _save_nifti(
            fullres_msk_tensor,
            os.path.join(reg_msk_savepath_fullres, utils.get_barcode([pid, e]) + "_GT.nii.gz"),
        )

    # --- Diffusion recovery via OMorpher ---
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print(
                f"  Generating -> Subject-{pid}, Scan-{e} "
                f'({im}/{hyp_parameters["aug_coe"]})',
                end="\r",
            )

            # Set up OMorpher inputs
            om.set_init_img(img)
            om.set_cond_img(target_img.clone().detach())

            # Run diffusion recovery
            # T=[None, timesteps] in original means: no initial noise, full reverse diffusion
            om.predict(
                T=[None, hyp_parameters["timesteps"]],
                proc_type=hyp_parameters["condition_type"],
            )
            ddf_comp = om.get_def()

            # Reconstruct images at model resolution using OMorpher
            img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
            step_barcode = utils.get_barcode([pid, e, im, noise_step])

            # --- Save model-resolution results ---
            _save_nifti(
                img_rec,
                os.path.join(hyp_parameters["reg_img_savepath"], step_barcode + ".nii.gz"),
            )
            if mask is not None:
                msk_rec = om.apply_def(
                    img=mask,
                    ddf=ddf_comp,
                    padding_mode="zeros",
                    resample_mode="nearest",
                )
                _save_nifti(
                    msk_rec,
                    os.path.join(hyp_parameters["reg_msk_savepath"], step_barcode + "_GT.nii.gz"),
                )

            # --- Upscale DDF and apply at full resolution via OMorpher ---
            img_rec_fullres = om.apply_def(
                img=fullres_img_tensor,
                ddf=ddf_comp,
                padding_mode="border",
            )
            if fullres_msk_tensor is not None:
                msk_rec_fullres = om.apply_def(
                    img=fullres_msk_tensor,
                    ddf=ddf_comp,
                    padding_mode="zeros",
                    resample_mode="nearest",
                )

            # Upscale DDF for saving.
            # NOTE(review): displacement values are resampled but not rescaled;
            # this is only correct if get_def() returns resolution-independent
            # (e.g. normalized) displacements. Confirm against OMorpher.
            ddf_fullres = F.interpolate(
                ddf_comp,
                size=orig_sz,
                mode="trilinear",
                align_corners=False,
            )

            # --- Save full-res results ---
            _save_nifti(
                img_rec_fullres,
                os.path.join(reg_img_savepath_fullres, step_barcode + ".nii.gz"),
            )
            if fullres_msk_tensor is not None:
                _save_nifti(
                    msk_rec_fullres,
                    os.path.join(reg_msk_savepath_fullres, step_barcode + "_GT.nii.gz"),
                )
            _save_nifti(
                ddf_fullres,
                os.path.join(reg_ddf_savepath_fullres, step_barcode + ".nii.gz"),
            )

            # Kept for parity with the original. It only matters if the
            # single-iteration `im` loop is widened.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    if max_reg_subjects is not None and e + 1 >= max_reg_subjects:
        break