# Omini3D / Scripts/OM_aug_om.py
# Committed by: maxmo2009
# Commit 2af0e94 (verified): "Sync from local: code + epoch-110 checkpoint, clean README"
"""
OM_aug_om.py — Augmentation using OMorpher.
Drop-in replacement for OM_aug.py. Produces identical outputs but uses
OMorpher instead of DeformDDPM + STN + standalone apply_ddf().
Usage:
python Scripts/OM_aug_om.py -C Config/config_om.yaml
"""
import os
import sys
import argparse
# Add project root to path so imports work from Scripts/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import torch
import nibabel as nib
import yaml
from tqdm import tqdm
import utils
from Dataloader.dataLoader import OminiDataset_inference_w_all
from torch.utils.data import DataLoader
from OMorpher import OMorpher
# ========== CLI ==========
# One optional flag selects the YAML config file. Note that the default points
# at the CMR config; the usage example in the module docstring passes
# Config/config_om.yaml explicitly.
_CONFIG_FLAGS = ("--config", "-C")
_CONFIG_OPTIONS = {
    "help": "Path for the config file",
    "type": str,
    "default": "Config/config_cmr.yaml",
    "required": False,
}
parser = argparse.ArgumentParser()
parser.add_argument(*_CONFIG_FLAGS, **_CONFIG_OPTIONS)
args = parser.parse_args()
# ========== Config ==========
# Load the YAML hyper-parameters. YAML text is UTF-8 by specification, so the
# encoding is explicit instead of depending on the platform locale.
with open(args.config, "r", encoding="utf-8") as config_file:
    hyp_parameters = yaml.safe_load(config_file)
print(hyp_parameters)
# NOTE(review): the Dataset section further down overwrites these three paths
# with hard-coded Data/Aug_data/... locations, so the directories created here
# may never be written to.
for save_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    # exist_ok=True replaces the racy "check exists, then create" pattern.
    os.makedirs(hyp_parameters[save_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])
# Force one image per batch; the DataLoader below reads this value.
hyp_parameters["batchsize"] = 1
# ========== Dataset (identical to OM_aug.py) ==========
min_crop_ratio = 0.9
label_keys = ["heart"]
database = ["MnMs"]
# Image subtype to augment. Defaults to "es" (the previously hard-coded value);
# an optional ``aug_subtype`` config key selects another one without editing
# this script.
subtype = hyp_parameters.get("aug_subtype", "es")
# NOTE: these output locations replace any aug_*_savepath values that came
# from the config file.
hyp_parameters["aug_img_savepath"] = f"Data/Aug_data/mnms_{subtype}/img/"
hyp_parameters["aug_msk_savepath"] = f"Data/Aug_data/mnms_{subtype}/msk/"
hyp_parameters["aug_ddf_savepath"] = f"Data/Aug_data/mnms_{subtype}/ddf/"
# Presumably tells the dataset which "ImgDict" channel(s) to load -- confirm
# against OminiDataset_inference_w_all.
select_channels_dict = {"ImgDict": [subtype]}
dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
    select_channels_dict=select_channels_dict,
)
# shuffle=False: batches follow dataset order; the main loop uses the batch
# index in output file names.
Infer_Loader = DataLoader(
    dataset,
    batch_size=hyp_parameters["batchsize"],
    shuffle=False,
)
# ========== OMorpher setup ==========
# Checkpoint location: Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
data_net = f'{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
epoch = f'{hyp_parameters["model_id_str"]}_{data_net}'  # checkpoint file stem
model_dir = f"Models/{data_net}/"
model_save_path = os.path.join(model_dir, f"{epoch}.pth")
print("Loading model from:", model_save_path)
# Device comes from the config (stringified); falls back to CPU when absent.
device_name = str(hyp_parameters.get("device", "cpu"))
om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=device_name,
)
print(om)
# ========== Output directories ==========
# Make sure every (possibly overridden) output folder exists; pre-existing
# folders are accepted as-is.
for _out_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_out_key], exist_ok=True)
# ========== Main inference loop ==========
device = om.device


def _to_numpy(tensor):
    """Return a detached CPU NumPy copy of ``tensor``."""
    return tensor.cpu().detach().numpy()


def _save_nifti(array, directory, filename):
    """Convert ``array`` via ``utils.converet_to_nibabel`` and save it as ``directory/filename``."""
    nib.save(
        utils.converet_to_nibabel(array, ndims=hyp_parameters["ndims"]),
        os.path.join(directory, filename),
    )


def _augment_once(img, mask, noise_step):
    """Run one OMorpher augmentation pass starting at diffusion step ``noise_step``.

    Returns ``(img_rec, msk_rec, img_diff, msk_diff)``:
    ``img_rec`` / ``msk_rec`` are ``img`` / ``mask`` warped by the deformation
    recovered by reverse diffusion; ``img_diff`` is the forward-diffused image
    and ``msk_diff`` is ``mask`` warped by the same random DDF. Masks are
    resampled with nearest-neighbour interpolation.
    """
    # 1. Set init image (DataLoader tensor passes through)
    om.set_init_img(img)
    # 2. Self-conditioning (matches: cond_imgs = img_org.clone().detach())
    om.set_cond_img(img)
    # 3. Forward diffuse to get noisy image + random DDF
    t_start = torch.tensor(np.array([noise_step]), device=device)
    img_diff, _, ddf_rand = om._get_random_ddf(om._init_img, t_start)
    # 4. Warp the mask with the same random DDF to get the noisy mask
    msk_diff = om.apply_def(
        img=mask, ddf=ddf_rand,
        padding_mode="zeros", resample_mode="nearest",
    )
    # 5. Set random DDF as initial DDF
    om.set_init_def(ddf=ddf_rand.clone().detach())
    # 6. Run reverse diffusion from noise_step
    om.predict(
        T=[noise_step, hyp_parameters["timesteps"]],
        proc_type=hyp_parameters["condition_type"],
    )
    # 7. Apply the recovered deformation to the clean image and mask
    ddf_comp = om.get_def()
    img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
    msk_rec = om.apply_def(
        img=mask, ddf=ddf_comp,
        padding_mode="zeros", resample_mode="nearest",
    )
    return img_rec, msk_rec, img_diff, msk_diff


# Number of subjects to process before stopping. The default of 1 reproduces the
# previous behaviour (the loop used to end with an always-true ``if e >= 0: exit()``);
# set ``aug_max_subjects`` to null or a value <= 0 in the config to run the whole dataset.
max_subjects = hyp_parameters.get("aug_max_subjects", 1)
aug_coe = hyp_parameters["aug_coe"]

print("total num of image:", len(Infer_Loader))
# tqdm wraps the DataLoader (not the enumerate object) so it can read len() and
# show a real total.
for e, d in enumerate(tqdm(Infer_Loader)):
    pid = e  # batchsize is forced to 1, so the batch index is used as the subject ID
    print("Processing to patient:", pid, " image:", e)
    img = d["img"].type(torch.float32).to(device)
    mask = d["labels"].type(torch.float32).to(device)
    img_dir = hyp_parameters["aug_img_savepath"]
    msk_dir = hyp_parameters["aug_msk_savepath"]

    # Save original image and mask
    _save_nifti(_to_numpy(img), img_dir, utils.get_barcode([pid, e]) + ".nii.gz")
    _save_nifti(_to_numpy(mask), msk_dir, utils.get_barcode([pid, e]) + "_GT.nii.gz")

    # Augmentation loop: aug_coe samples per subject
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(aug_coe):
            print(
                f"Generating -> Subject-{pid}, Scan-{e} ({im}/{aug_coe})",
                end="\r",
            )
            img_rec, msk_rec, img_diff, msk_diff = _augment_once(img, mask, noise_step)

            aug_name = utils.get_barcode([pid, e, im, noise_step])
            noisy_name = utils.get_barcode(
                [pid, e, im, noise_step],
                header=["Patient", "Slice", "NoiseImg", "NoiseStep"],
            )
            # Save augmented (recovered) outputs
            _save_nifti(_to_numpy(img_rec), img_dir, aug_name + ".nii.gz")
            _save_nifti(_to_numpy(msk_rec), msk_dir, aug_name + "_GT.nii.gz")
            # Save noisy (forward-diffused) image/mask
            _save_nifti(_to_numpy(img_diff), img_dir, noisy_name + ".nii.gz")
            _save_nifti(_to_numpy(msk_diff), msk_dir, noisy_name + "_GT.nii.gz")
            # NOTE(review): this is equivalent to ``im % 2 == start_noise_step % 2``:
            # the step grows every second sample, phased by the start step's parity.
            # Presumably kept for parity with OM_aug.py -- confirm intent.
            # (No DDF is written here even though aug_ddf_savepath is created.)
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # Early stop after ``max_subjects`` subjects. ``break`` ends the script
    # normally (exit status 0, as before) without relying on the site-provided exit().
    if max_subjects is not None and 0 < max_subjects <= e + 1:
        break