"""
OM_aug_om.py — Augmentation using OMorpher.

Drop-in replacement for OM_aug.py. Produces identical outputs but uses
OMorpher instead of DeformDDPM + STN + standalone apply_ddf().

For every image in the MnMs (subtype "es") inference set this script saves
the original image/mask, then ``aug_coe`` augmented image/mask pairs and the
corresponding forward-diffused ("noisy") image/mask pairs as NIfTI files.

Optional config key:
    aug_max_images: int or null — stop after this many images
                    (null/absent = process the whole dataset).

Usage:
    python Scripts/OM_aug_om.py -C Config/config_om.yaml
"""

import argparse
import os
import sys

# Add project root to path so imports work from Scripts/
# (must run before the project-local imports below).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import nibabel as nib
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils
from Dataloader.dataLoader import OminiDataset_inference_w_all
from OMorpher import OMorpher

# ========== Dataset selection (identical to OM_aug.py) ==========
MIN_CROP_RATIO = 0.9
LABEL_KEYS = ["heart"]
DATABASE = ["MnMs"]
SUBTYPE = "es"

# Barcode header used for the forward-diffused (noisy) outputs.
NOISY_HEADER = ["Patient", "Slice", "NoiseImg", "NoiseStep"]


def parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        "-C",
        help="Path for the config file",
        type=str,
        default="Config/config_cmr.yaml",
        required=False,
    )
    return parser.parse_args(argv)


def load_config(path):
    """Load the YAML config and apply this script's fixed overrides.

    The output paths are overridden with per-subtype locations (as in
    OM_aug.py), so any ``aug_*_savepath`` values in the config file are
    ignored. ``batchsize`` is forced to 1.

    Returns:
        dict: the (mutated) hyper-parameter dictionary.
    """
    with open(path, "r") as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    hyp_parameters["batchsize"] = 1
    hyp_parameters["aug_img_savepath"] = f"Data/Aug_data/mnms_{SUBTYPE}/img/"
    hyp_parameters["aug_msk_savepath"] = f"Data/Aug_data/mnms_{SUBTYPE}/msk/"
    # NOTE: the ddf directory is created but nothing in this script writes to it.
    hyp_parameters["aug_ddf_savepath"] = f"Data/Aug_data/mnms_{SUBTYPE}/ddf/"
    return hyp_parameters


def build_loader(batch_size):
    """Build the non-shuffled inference DataLoader for the selected dataset."""
    dataset = OminiDataset_inference_w_all(
        transform=None,
        min_crop_ratio=MIN_CROP_RATIO,
        label_key=LABEL_KEYS,
        database=DATABASE,
        select_channels_dict={"ImgDict": [SUBTYPE]},
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


def load_morpher(hyp_parameters):
    """Construct an OMorpher from the checkpoint named by the config."""
    epoch = (
        f'{hyp_parameters["model_id_str"]}_'
        f'{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
    )
    model_save_path = os.path.join(
        f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
        str(epoch) + ".pth",
    )
    print("Loading model from:", model_save_path)
    om = OMorpher(
        config=hyp_parameters,
        checkpoint_path=model_save_path,
        device=str(hyp_parameters.get("device", "cpu")),
    )
    print(om)
    return om


def save_nifti(array, ndims, directory, barcode, suffix=".nii.gz"):
    """Convert ``array`` with ``utils.converet_to_nibabel`` and save it.

    The file is written to ``directory/<barcode><suffix>``.
    """
    nib.save(
        utils.converet_to_nibabel(array, ndims=ndims),
        os.path.join(directory, barcode + suffix),
    )


def augment_sample(om, img, mask, pid, e, hyp_parameters):
    """Save one original image/mask plus ``aug_coe`` augmented/noisy pairs.

    Args:
        om: the loaded OMorpher.
        img, mask: one batch from the DataLoader (batch size 1). Their exact
            shape/layout is whatever the dataset yields; TODO confirm.
        pid, e: patient / scan indices used in the output barcodes.
        hyp_parameters: config dict (reads ndims, aug_coe, start_noise_step,
            noise_step, timesteps, condition_type and the save paths).
    """
    device = om.device
    ndims = hyp_parameters["ndims"]
    img_dir = hyp_parameters["aug_img_savepath"]
    msk_dir = hyp_parameters["aug_msk_savepath"]
    aug_coe = hyp_parameters["aug_coe"]

    img = img.type(torch.float32).to(device)
    mask = mask.type(torch.float32).to(device)

    # Save original image and mask.
    base_barcode = utils.get_barcode([pid, e])
    save_nifti(img.cpu().detach().numpy(), ndims, img_dir, base_barcode)
    save_nifti(mask.cpu().detach().numpy(), ndims, msk_dir, base_barcode, "_GT.nii.gz")

    # Augmentation loop.
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(aug_coe):
            print(
                f"Generating -> Subject-{pid}, Scan-{e} ({im}/{aug_coe})",
                end="\r",
            )
            # 1. Set init image (DataLoader tensor passes through).
            om.set_init_img(img)
            # 2. Self-conditioning (matches OM_aug.py:
            #    cond_imgs = img_org.clone().detach()).
            om.set_cond_img(img)
            # 3. Forward diffuse to get noisy image + random DDF.
            t_start = torch.tensor(np.array([noise_step]), device=device)
            img_diff, _, ddf_rand = om._get_random_ddf(om._init_img, t_start)
            # 4. Warp the mask with the same random DDF. Nearest-neighbour
            #    resampling presumably keeps label values discrete.
            msk_diff = om.apply_def(
                img=mask,
                ddf=ddf_rand,
                padding_mode="zeros",
                resample_mode="nearest",
            )
            # 5. Use the random DDF as the initial deformation.
            om.set_init_def(ddf=ddf_rand.clone().detach())
            # 6. Run reverse diffusion starting from noise_step.
            om.predict(
                T=[noise_step, hyp_parameters["timesteps"]],
                proc_type=hyp_parameters["condition_type"],
            )
            # 7. Apply the recovered deformation to the original image/mask.
            ddf_comp = om.get_def()
            img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
            msk_rec = om.apply_def(
                img=mask,
                ddf=ddf_comp,
                padding_mode="zeros",
                resample_mode="nearest",
            )

            # Save augmented (recovered) outputs.
            aug_barcode = utils.get_barcode([pid, e, im, noise_step])
            save_nifti(img_rec.cpu().detach().numpy(), ndims, img_dir, aug_barcode)
            save_nifti(
                msk_rec.cpu().detach().numpy(), ndims, msk_dir, aug_barcode, "_GT.nii.gz"
            )

            # Save noisy (forward-diffused) image/mask.
            noisy_barcode = utils.get_barcode(
                [pid, e, im, noise_step], header=NOISY_HEADER
            )
            save_nifti(img_diff.cpu().detach().numpy(), ndims, img_dir, noisy_barcode)
            save_nifti(
                msk_diff.cpu().detach().numpy(), ndims, msk_dir, noisy_barcode, "_GT.nii.gz"
            )

            # NOTE(review): the schedule mixes the augmentation index `im` with
            # `start_noise_step`. The step grows every other iteration, and the
            # phase depends on the parity of start_noise_step. It is kept as-is
            # for output parity with OM_aug.py; confirm this is intended.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]


def main(argv=None):
    """Run augmentation over the whole inference set (or ``aug_max_images``)."""
    args = parse_args(argv)
    hyp_parameters = load_config(args.config)

    # Create only the directories that are actually used (after overrides).
    for key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
        os.makedirs(hyp_parameters[key], exist_ok=True)
    print(hyp_parameters["aug_img_savepath"])

    loader = build_loader(hyp_parameters["batchsize"])
    om = load_morpher(hyp_parameters)

    print("total num of image:", len(loader))
    # The previous unconditional `if e >= 0: exit()` stopped after the first
    # image. Set `aug_max_images: 1` in the config to reproduce that.
    max_images = hyp_parameters.get("aug_max_images")
    for e, d in enumerate(tqdm(loader)):
        if max_images is not None and e >= max_images:
            break
        pid = e
        print("Processing to patient:", pid, " image:", e)
        augment_sample(om, d["img"], d["labels"], pid, e, hyp_parameters)


if __name__ == "__main__":
    main()