# Omini3D / Scripts/OM_reg_flexres_om.py
# Committed by maxmo2009 — commit 2af0e94 (verified):
#   "Sync from local: code + epoch-110 checkpoint, clean README"
"""
OM_reg_flexres_om.py — Full-resolution registration using OMorpher.
Drop-in replacement for OM_reg_flexres.py. Produces identical outputs but
uses OMorpher instead of DeformDDPM + STN + standalone apply_ddf().
Usage:
python Scripts/OM_reg_flexres_om.py -C Config/config_om.yaml
"""
import os
import sys
import argparse
# Add project root to path so imports work from Scripts/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
import yaml
import SimpleITK as sitk
from tqdm import tqdm
import utils
from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order
from OMorpher import OMorpher
# ========== CLI ==========
# One optional flag (long form first, so argparse derives dest="config")
# selects the YAML file that drives the rest of the script.
_CONFIG_FLAGS = ("--config", "-C")
parser = argparse.ArgumentParser()
parser.add_argument(
    *_CONFIG_FLAGS,
    help="Path for the config file",
    type=str,
    default="Config/config_om.yaml",
    required=False,
)
args = parser.parse_args()
# ========== Config ==========
# Load the YAML run configuration. safe_load returns None for an empty file,
# which would otherwise surface later as an opaque "NoneType is not
# subscriptable" error, so reject anything that is not a mapping up front.
with open(args.config, "r", encoding="utf-8") as cfg_file:
    hyp_parameters = yaml.safe_load(cfg_file)
if not isinstance(hyp_parameters, dict):
    raise ValueError(
        f"Config file {args.config!r} must contain a YAML mapping, "
        f"got {type(hyp_parameters).__name__}"
    )
print(hyp_parameters)

# exist_ok=True avoids the exists()/makedirs() race and matches how the
# reg_* output directories are created further down.
for _aug_dir_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_aug_dir_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Inference processes one subject at a time.
hyp_parameters["batchsize"] = 1
model_img_sz = hyp_parameters["img_size"]
# ========== Dataset (unchanged — used only for filtering/metadata) ==========
# The loader is asked for the "brain" label of the Brats2019 collection only;
# the loop below reads file paths/metadata from it rather than its tensors.
label_keys = ["brain"]
database = ["Brats2019"]
_dataset_kwargs = dict(
    transform=None,
    min_crop_ratio=1.0,
    label_key=label_keys,
    database=database,
)
dataset = OminiDataset_inference_w_all(**_dataset_kwargs)
# ========== OMorpher setup ==========
# Checkpoint layout:
#   Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
_model_id = hyp_parameters["model_id_str"]
_run_name = f'{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
epoch = f"{_model_id}_{_run_name}"  # checkpoint stem (not a numeric epoch)
model_save_path = os.path.join(f"Models/{_run_name}/", epoch + ".pth")
print("Loading model from:", model_save_path)

_om_device = str(hyp_parameters.get("device", "cpu"))
om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=_om_device,
)
print(om)
# ========== Output directories ==========
# Each model-resolution reg_* directory gets a sibling whose name has
# "_fullres/" appended (after dropping a trailing slash) for the
# original-resolution outputs. All six directories are created if missing.
_reg_base_dirs = [
    hyp_parameters["reg_img_savepath"],
    hyp_parameters["reg_msk_savepath"],
    hyp_parameters["reg_ddf_savepath"],
]
reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres = (
    base.rstrip("/") + "_fullres/" for base in _reg_base_dirs
)
for p in _reg_base_dirs + [
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
]:
    os.makedirs(p, exist_ok=True)
# ========== Helper: load full-res data (same as original) ==========
def center_pad_to_cube(volume):
    """Zero-pad the first three axes of ``volume`` to a cube, keeping it centred.

    The cube side is the largest of the first three dimensions. When a
    dimension needs an odd amount of padding, the extra voxel goes after the
    data. Any trailing axes beyond the third (e.g. channels) are left unpadded.

    Args:
        volume: NumPy array with at least three leading spatial axes.

    Returns:
        A new array whose first three dimensions are all equal.
    """
    spatial = volume.shape[:3]
    side = max(spatial)
    widths = []
    for extent in spatial:
        before, remainder = divmod(side - extent, 2)
        widths.append((before, before + remainder))
    widths.extend([(0, 0)] * (volume.ndim - 3))
    return np.pad(volume, widths, mode="constant", constant_values=0)
def load_fullres_volume(key, ds):
    """Read one image at its native resolution and prepare it like the dataset does.

    Steps: read ``key`` with SimpleITK, reorder axes via ``reverse_axis_order``,
    keep a single channel for 4-D images (the first id from
    ``ds.get_channel_ids(key)``, falling back to channel 0), clip to
    ``ds.clamp_range`` only when a clamp range is set and the record's
    "Modality" is "CT", normalize with ``ds.normalize``, then centre-pad to a cube.

    Args:
        key: Image path, also used as the key into ``ds.ALLdata_filtered``.
        ds: Dataset object providing the helpers listed above.

    Returns:
        The cube-padded NumPy volume.
    """
    arr = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    if arr.ndim == 4:
        ids = ds.get_channel_ids(key)
        arr = arr[:, :, :, ids[0] if len(ids) > 0 else 0]

    if ds.clamp_range is not None and ds.ALLdata_filtered[key].get("Modality", None) == "CT":
        low, high = ds.clamp_range[0], ds.clamp_range[1]
        arr = np.clip(arr, low, high)

    return center_pad_to_cube(ds.normalize(arr))
def load_fullres_label(key, ds, label_key):
    """Read one segmentation label at native resolution and centre-pad it to a cube.

    The label path is looked up under
    ``ds.ALLdata_filtered[key]["Label_path"]["segmentation"][label_key]``.
    If that entry is absent, ``None`` is returned. For labels with more than
    three dimensions, the channels from ``ds.get_channel_ids(key)`` are
    selected along the last axis when that list is non-empty.
    NOTE(review): with a list of ids, NumPy fancy indexing keeps the channel
    axis, so the result may stay 4-D — confirm downstream expects that.

    Args:
        key: Image path / record key into ``ds.ALLdata_filtered``.
        ds: Dataset object holding metadata and ``get_channel_ids``.
        label_key: Name of the segmentation label (e.g. "brain").

    Returns:
        The cube-padded label array, or ``None`` when the label is unavailable.
    """
    seg_paths = ds.ALLdata_filtered[key].get("Label_path", {}).get("segmentation", {})
    if label_key not in seg_paths:
        return None

    lbl = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(seg_paths[label_key])))

    if lbl.ndim > 3:
        ids = ds.get_channel_ids(key)
        if len(ids) != 0:
            lbl = lbl[..., ids]

    return center_pad_to_cube(lbl)
# ========== Main inference loop ==========
def _save_nifti(data, directory, barcode_parts, suffix=""):
    """Save ``data`` as ``<barcode><suffix>.nii.gz`` inside ``directory``.

    ``data`` is handed unchanged to ``utils.converet_to_nibabel``. As in the
    original script, model-resolution outputs are passed as NumPy arrays and
    full-resolution outputs as torch tensors.
    """
    nib.save(
        utils.converet_to_nibabel(data, ndims=hyp_parameters["ndims"]),
        os.path.join(directory, utils.get_barcode(barcode_parts) + suffix + ".nii.gz"),
    )


def _load_label_tensors(key):
    """Load every label in ``label_keys`` for ``key`` at model and full resolution.

    Returns:
        ``(mask, fullres_mask)``, each the channel-wise (dim=1) concatenation
        of the per-label tensors from ``om._standardize_label``, or
        ``(None, None)`` when ``label_keys`` is empty.
    """
    masks_model, masks_fullres = [], []
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        model_t, fullres_t = om._standardize_label(lab)  # None → -1 placeholder
        masks_model.append(model_t)
        masks_fullres.append(fullres_t)
    if not masks_model:
        return None, None
    return torch.cat(masks_model, dim=1), torch.cat(masks_fullres, dim=1)


keys = list(dataset.ALLdata_filtered.keys())
print("total num of images:", len(keys))

# The original script stopped after 7 subjects via a hard-coded
# `if e > 5: break`, while tqdm still advertised the full count. The cap is
# now configurable (default 7 keeps previous behavior; set
# `reg_max_subjects: null` to register every subject). Slicing up front makes
# the progress bar report the real amount of work.
max_subjects = hyp_parameters.get("reg_max_subjects", 7)
keys_to_run = keys if max_subjects is None else keys[:max_subjects]
if len(keys_to_run) < len(keys):
    print(f"Registering only the first {len(keys_to_run)} subjects (config 'reg_max_subjects').")

device = om.device

for e, key in enumerate(tqdm(keys_to_run)):
    pid = e
    print(f"Processing patient {pid}, image {e}, key: {key}")

    # --- Load & standardize volume via OMorpher ---
    fullres_vol = load_fullres_volume(key, dataset)
    om.set_init_img(fullres_vol)
    img = om._init_img  # [1, 1, model_sz, model_sz, model_sz]
    fullres_img_tensor = om._init_img_raw  # [1, 1, D, H, W] full-res tensor
    orig_sz = list(fullres_img_tensor.shape[2:])
    print(f"  Full-res padded shape: {orig_sz}")

    # --- Load & standardize labels via OMorpher ---
    mask, fullres_msk_tensor = _load_label_tensors(key)

    # The first subject's model-resolution image becomes the conditioning
    # (fixed) image for every subject, including itself.
    if e == 0:
        target_img = img.clone().detach()

    # --- Save originals at model resolution and at full resolution ---
    _save_nifti(img.cpu().numpy(), hyp_parameters["reg_img_savepath"], [pid, e])
    if mask is not None:
        _save_nifti(mask.cpu().numpy(), hyp_parameters["reg_msk_savepath"], [pid, e], "_GT")
    _save_nifti(fullres_img_tensor, reg_img_savepath_fullres, [pid, e])
    if fullres_msk_tensor is not None:
        _save_nifti(fullres_msk_tensor, reg_msk_savepath_fullres, [pid, e], "_GT")

    # --- Diffusion recovery via OMorpher ---
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print(
                f"  Generating -> Subject-{pid}, Scan-{e} "
                f'({im}/{hyp_parameters["aug_coe"]})',
                end="\r",
            )
            om.set_init_img(img)
            om.set_cond_img(target_img.clone().detach())

            # T=[None, timesteps] in original means: no initial noise, full reverse diffusion
            om.predict(
                T=[None, hyp_parameters["timesteps"]],
                proc_type=hyp_parameters["condition_type"],
            )
            ddf_comp = om.get_def()

            # --- Model-resolution results ---
            img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
            _save_nifti(
                img_rec.cpu().numpy(),
                hyp_parameters["reg_img_savepath"],
                [pid, e, im, noise_step],
            )
            if mask is not None:
                msk_rec = om.apply_def(
                    img=mask, ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                _save_nifti(
                    msk_rec.cpu().numpy(),
                    hyp_parameters["reg_msk_savepath"],
                    [pid, e, im, noise_step],
                    "_GT",
                )

            # --- Full-resolution results (OMorpher upsamples the DDF internally) ---
            # NOTE(review): the full-res image uses padding_mode="border" while the
            # model-res image uses "zeros"; kept as-is, confirm this is intended.
            img_rec_fullres = om.apply_def(
                img=fullres_img_tensor, ddf=ddf_comp, padding_mode="border",
            )
            if fullres_msk_tensor is not None:
                msk_rec_fullres = om.apply_def(
                    img=fullres_msk_tensor, ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
            # Upscale DDF for saving.
            # NOTE(review): interpolate resamples the field but does not rescale
            # its values; if ddf_comp is in voxel units (not normalized coords),
            # the saved magnitudes need scaling by orig/model size — confirm.
            ddf_fullres = F.interpolate(
                ddf_comp, size=orig_sz, mode="trilinear", align_corners=False,
            )

            _save_nifti(img_rec_fullres, reg_img_savepath_fullres, [pid, e, im, noise_step])
            if fullres_msk_tensor is not None:
                _save_nifti(
                    msk_rec_fullres, reg_msk_savepath_fullres, [pid, e, im, noise_step], "_GT",
                )
            _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, [pid, e, im, noise_step])

            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]