# Omini3D / OM_reg_flexres.py
# maxmo2009's picture
# Sync from local: code + epoch-110 checkpoint, clean README
# 2af0e94 verified
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.utils import make_grid
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import random
import os
import utils
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *
from torchvision.utils import save_image
from einops import rearrange, reduce, repeat
import numpy as np
import nibabel as nib
from tqdm import tqdm
import yaml
import argparse
import torch.nn.functional as F
import SimpleITK as sitk
from skimage.transform import resize
# Module-level epsilon; it is not referenced anywhere in the visible part of
# this script. NOTE(review): the literal 10e-8 evaluates to 1e-7, not 1e-8 --
# confirm the intended magnitude before relying on it.
EPS = 10e-8

# Command-line interface: one optional argument giving the YAML config path.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", "-C",
    type=str,
    required=False,
    default="Config/config_om.yaml",
    help="Path for the config file",
)
args = parser.parse_args()
#=======================================================================================================================
# Load the YAML configuration file into a plain dictionary.
with open(args.config, 'r', encoding='utf-8') as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Make sure the augmentation output directories exist. ``exist_ok=True``
# avoids the race between a separate "exists?" check and the creation call,
# and matches how the registration output directories are created further
# down in this script. Unlike the old check, a regular *file* sitting at one
# of these paths now raises immediately instead of being silently accepted.
for _savepath_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_savepath_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Inference handles one volume at a time, whatever the config file says.
hyp_parameters['batchsize'] = 1
model_img_sz = hyp_parameters['img_size']  # e.g. 128
# =======================================================================================================================
# The dataset object is created only for its filtering logic: it supplies the
# set of keys and the per-key metadata (``ALLdata_filtered``). Volumes are
# loaded directly further down instead of going through a DataLoader, so the
# center-padding is deterministic and identical for the model-resolution
# input and the full-resolution volume.
label_keys = ['brain']
database = ['Brats2019']
dataset = OminiDataset_inference_w_all(
    database=database,
    label_key=label_keys,
    min_crop_ratio=1.0,
    transform=None,
)
# =======================================================================================================================
# Checkpoint location:
#   Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = os.path.join(
    f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
    str(epoch) + '.pth',
)

# Build the backbone network first, then wrap it in the diffusion model.
Net = get_net(hyp_parameters["net_name"])
_backbone = Net(
    n_steps=hyp_parameters["timesteps"],
    ndims=hyp_parameters["ndims"],
    num_input_chn=hyp_parameters["num_input_chn"],
    res=model_img_sz,
)
Deformddpm = DeformDDPM(
    network=_backbone,
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"]] + [model_img_sz] * hyp_parameters["ndims"],
    device=hyp_parameters["device"],
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
    inf_mode=True,
)
Deformddpm.to(hyp_parameters["device"])

# Spatial transformer at model resolution. It is not referenced again in the
# visible part of this script (full-resolution warping uses apply_ddf).
ddf_stn = STN(
    img_sz=model_img_sz,
    ndims=hyp_parameters["ndims"],
    padding_mode=hyp_parameters['padding_mode'],
    device=hyp_parameters["device"],
)
ddf_stn.to(hyp_parameters["device"])

# Restore the trained weights and switch to evaluation mode.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
print("Loading model from:", model_save_path)
checkpoint = torch.load(model_save_path, map_location='cpu')
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()
# Full-resolution results are written to sibling directories whose names are
# the standard output paths (trailing '/' removed) plus a "_fullres/" suffix.
reg_img_savepath_fullres = f"{hyp_parameters['reg_img_savepath'].rstrip('/')}_fullres/"
reg_msk_savepath_fullres = f"{hyp_parameters['reg_msk_savepath'].rstrip('/')}_fullres/"
reg_ddf_savepath_fullres = f"{hyp_parameters['reg_ddf_savepath'].rstrip('/')}_fullres/"

# Create the model-resolution directories first, then the full-res ones.
for _out_dir in (
    hyp_parameters['reg_img_savepath'],
    hyp_parameters['reg_msk_savepath'],
    hyp_parameters['reg_ddf_savepath'],
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
):
    os.makedirs(_out_dir, exist_ok=True)
# ========== Helper functions ==========
def center_pad_to_cube(volume):
    """Zero-pad the first three axes of ``volume`` to a common length.

    The target length is the largest size among the first three axes. Each of
    those axes is padded on both sides; when the required padding is odd, the
    extra element goes on the trailing side. Any axes after the third (e.g.
    label channels) are left unpadded.

    Returns:
        A new array produced by ``np.pad`` with constant value 0.
    """
    spatial_shape = volume.shape[:3]
    side = max(spatial_shape)

    # (before, after) padding per spatial axis, centered on the original data.
    spatial_pads = [
        ((side - n) // 2, (side - n) - (side - n) // 2)
        for n in spatial_shape
    ]
    # No padding on extra (non-spatial) axes.
    extra_pads = [(0, 0)] * (volume.ndim - 3)

    return np.pad(volume, spatial_pads + extra_pads,
                  mode='constant', constant_values=0)
def load_fullres_volume(key, ds):
    """Load one image at its original resolution and pad it to a cube.

    Steps, in order: read ``key`` with SimpleITK, convert to an array, pass
    it through ``reverse_axis_order``, pick a single channel if the array is
    4-D, optionally clip CT intensities, normalize with ``ds.normalize`` and
    finally center-pad with ``center_pad_to_cube``. No resizing is done.

    Args:
        key: Path handed to ``sitk.ReadImage``; also the lookup key into
            ``ds.ALLdata_filtered`` and ``ds.get_channel_ids``.
        ds: Dataset object providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.

    Returns:
        The padded volume (first three axes of equal length).
    """
    volume = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    if volume.ndim == 4:
        # Multi-channel image: keep the first listed channel, or channel 0
        # when the dataset lists none.
        channel_ids = ds.get_channel_ids(key)
        channel_id = channel_ids[0] if len(channel_ids) > 0 else 0
        volume = volume[:, :, :, channel_id]

    # Intensity clamping applies only to CT, and only if a range is configured.
    clamp = ds.clamp_range
    if clamp is not None and ds.ALLdata_filtered[key].get("Modality", None) == "CT":
        volume = np.clip(volume, clamp[0], clamp[1])

    return center_pad_to_cube(ds.normalize(volume))  # shape: [D, D, D] (cubic)
def load_fullres_label(key, ds, label_key):
    """Load one segmentation label at its original resolution, padded to a cube.

    The label path is looked up under
    ``ds.ALLdata_filtered[key]['Label_path']['segmentation'][label_key]``.
    The file is read with SimpleITK, passed through ``reverse_axis_order``,
    restricted to the dataset's channel ids when it has more than three axes
    (and the dataset lists any), and center-padded. No resizing is done.

    Returns:
        The padded label array, or ``None`` when no segmentation path is
        registered for ``label_key``.
    """
    seg_paths = ds.ALLdata_filtered[key].get('Label_path', {}).get('segmentation', {})
    if label_key not in seg_paths:
        return None

    label = reverse_axis_order(
        sitk.GetArrayFromImage(sitk.ReadImage(seg_paths[label_key])))

    if label.ndim > 3:
        # Keep only the channels the dataset associates with this key.
        channel_ids = ds.get_channel_ids(key)
        if len(channel_ids) != 0:
            label = label[..., channel_ids]

    return center_pad_to_cube(label)
def apply_ddf(volume_tensor, ddf, padding_mode='border', resample_mode='bilinear'):
    """Warp ``volume_tensor`` with a dense displacement field (DDF).

    The DDF stores fractional displacements: ``value * axis_length`` is the
    displacement in voxels along that axis. Because the values are fractions,
    a DDF that was spatially upscaled (e.g. by trilinear interpolation from
    model resolution to full resolution) can be applied directly -- the axis
    lengths of ``volume_tensor`` are used as the scale.

    The number of spatial axes is taken from ``volume_tensor`` instead of
    being fixed to 3, so 2-D inputs ``[B, C, H, W]`` work as well as 3-D
    ``[B, C, D, H, W]``.

    Args:
        volume_tensor: Tensor of shape ``[B, C, *spatial]`` to be resampled.
        ddf: Displacement field of shape ``[B, ndims, *spatial]`` with the
            same spatial size as ``volume_tensor``; channel ``i`` displaces
            along spatial axis ``i``. All helper tensors are created on
            ``ddf.device``.
        padding_mode: Forwarded to ``F.grid_sample``.
        resample_mode: Forwarded to ``F.grid_sample`` as ``mode``.

    Returns:
        The warped tensor, same shape as ``volume_tensor``.
    """
    device = ddf.device
    img_sz = list(volume_tensor.shape[2:])
    ndims = len(img_sz)

    # Per-axis length, broadcastable against the DDF: [1, ndims, 1, ..., 1].
    max_sz = torch.tensor(img_sz, dtype=torch.float32, device=device).reshape(
        [1, ndims] + [1] * ndims)

    # Identity sampling positions in voxel indices: [1, ndims, *spatial].
    ref_grid = torch.stack(
        torch.meshgrid(
            *[torch.arange(s, device=device) for s in img_sz], indexing='ij'),
        0).reshape([1, ndims] + img_sz)

    # Half of (length - 1) maps voxel indices onto [-1, 1] for
    # align_corners=True. A length-1 axis would give 0 here and turn the whole
    # grid into NaN, so its divisor is clamped; index 0 then maps to -1, which
    # grid_sample resolves back to index 0.
    half_extent = torch.tensor(
        [max(s - 1, 1) / 2. for s in img_sz],
        dtype=torch.float32, device=device).reshape([1] + [1] * ndims + [ndims])

    # Displaced positions, moved to channel-last layout for grid_sample.
    positions = (ddf * max_sz + ref_grid).permute(
        [0] + list(range(2, 2 + ndims)) + [1])

    # grid_sample expects the last axis ordered (x, y[, z]), i.e. the reverse
    # of the tensor's spatial axis order, hence the flip.
    grid = torch.flip(positions / half_extent - 1, dims=[-1])

    return F.grid_sample(volume_tensor, grid.float(), mode=resample_mode,
                         padding_mode=padding_mode, align_corners=True)
# ========== Main inference loop ==========
def _to_channel_first(arr):
    """Return ``arr`` with a leading channel axis: add one to a 3-D array,
    move the 4th axis to the front for arrays with more than three axes."""
    if arr.ndim == 3:
        return arr[None, :, :, :]
    if arr.ndim > 3:
        return np.transpose(arr, (3, 0, 1, 2))
    return arr


def _save_nifti(data, out_dir, filename):
    """Write ``data`` to ``out_dir/filename`` via ``utils.converet_to_nibabel``.

    Torch tensors are moved to host memory and converted to NumPy first, so
    the converter always receives the same array type, whether the data came
    from NumPy preprocessing or from a (possibly GPU-resident) tensor.
    """
    if torch.is_tensor(data):
        data = data.detach().cpu().numpy()
    nib.save(utils.converet_to_nibabel(data, ndims=hyp_parameters["ndims"]),
             os.path.join(out_dir, filename))


keys = list(dataset.ALLdata_filtered.keys())
print("total num of images:", len(keys))
for e, key in enumerate(tqdm(keys)):
    pid = e
    print(f'Processing patient {pid}, image {e}, key: {key}')

    # --- Load full-resolution volume (center-padded to cube) ---
    fullres_vol = load_fullres_volume(key, dataset)
    orig_sz = list(fullres_vol.shape)  # e.g. [240, 240, 240]
    print(f" Full-res padded shape: {orig_sz}")

    # --- Resize to model resolution for inference ---
    vol_model = resize(fullres_vol, [model_img_sz] * 3,
                       anti_aliasing=True, preserve_range=True)
    img = torch.tensor(vol_model[None, None, :, :, :],
                       dtype=torch.float32, device=hyp_parameters["device"])

    # --- Load full-res labels (keys without a label file are skipped) ---
    fullres_labels = {}
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        if lab is not None:
            fullres_labels[lk] = lab

    # --- Build channel-first masks at model resolution and at full resolution ---
    label_arrays_model = []
    label_arrays_fullres = []
    for lk in label_keys:
        if lk in fullres_labels:
            lab = fullres_labels[lk]
            # order=0 (nearest neighbour) without anti-aliasing keeps label
            # values discrete.
            lab_model = resize(lab, [model_img_sz] * 3,
                               anti_aliasing=False, preserve_range=True, order=0)
            label_arrays_model.append(_to_channel_first(lab_model))
            label_arrays_fullres.append(_to_channel_first(lab))
        else:
            # Missing label: a single channel filled with -1 as placeholder.
            label_arrays_model.append(np.full([1] + [model_img_sz] * 3, -1))
            label_arrays_fullres.append(np.full([1] + orig_sz, -1))

    if len(label_arrays_model) > 0:
        mask_model_np = np.concatenate(label_arrays_model, axis=0)
        mask = torch.tensor(mask_model_np[None], dtype=torch.float32,
                            device=hyp_parameters["device"])
        fullres_msk_np = np.concatenate(label_arrays_fullres, axis=0)
        fullres_msk_tensor = torch.tensor(fullres_msk_np[None], dtype=torch.float32,
                                          device=hyp_parameters["device"])
    else:
        mask = None
        fullres_msk_np = None
        fullres_msk_tensor = None

    # Build full-res image tensor
    fullres_img_tensor = torch.tensor(fullres_vol[None, None, :, :, :],
                                      dtype=torch.float32,
                                      device=hyp_parameters["device"])

    # --- The first subject becomes the conditioning target for all subjects ---
    if e <= 0:
        target_img = img.clone().detach()

    # --- Save original image / mask at model resolution ---
    original_code = utils.get_barcode([pid, e])
    _save_nifti(img.cpu().numpy(), hyp_parameters['reg_img_savepath'],
                original_code + '.nii.gz')
    if mask is not None:
        _save_nifti(mask.cpu().numpy(), hyp_parameters['reg_msk_savepath'],
                    original_code + '_GT.nii.gz')

    # --- Save original image / mask at full resolution ---
    # fullres_vol is [D,D,D], wrapped as [1,1,D,D,D] for converet_to_nibabel
    _save_nifti(fullres_vol[None, None], reg_img_savepath_fullres,
                original_code + '.nii.gz')
    if fullres_msk_np is not None:
        # fullres_msk_np is [C,D,D,D], wrapped as [1,C,D,D,D]
        _save_nifti(fullres_msk_np[None], reg_msk_savepath_fullres,
                    original_code + '_GT.nii.gz')

    # --- Diffusion recovery at model resolution ---
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print(f' Generating -> Subject-{pid}, Scan-{e} ({im}/{hyp_parameters["aug_coe"]})', end='\r')
            [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = \
                Deformddpm.diff_recover(
                    img_org=img,
                    cond_imgs=target_img.clone().detach(),
                    msk_org=mask,
                    T=[None, hyp_parameters["timesteps"]],
                    v_scale=hyp_parameters["v_scale"],
                    t_save=None,
                    proc_type=hyp_parameters["condition_type"])

            # File stems: one for the recovered output, one (with an explicit
            # header) for the output warped by the random field.
            rec_code = utils.get_barcode([pid, e, im, noise_step])
            noisy_code = utils.get_barcode(
                [pid, e, im, noise_step],
                header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep'])

            # --- Save model-resolution results (same as OM_reg.py) ---
            _save_nifti(img_rec, hyp_parameters['reg_img_savepath'],
                        rec_code + '.nii.gz')
            _save_nifti(img_diff, hyp_parameters['reg_img_savepath'],
                        noisy_code + '.nii.gz')
            if msk_rec is not None:
                _save_nifti(msk_rec, hyp_parameters['reg_msk_savepath'],
                            rec_code + '_GT.nii.gz')
                _save_nifti(msk_diff, hyp_parameters['reg_msk_savepath'],
                            noisy_code + '_GT.nii.gz')

            # --- Upscale DDFs to original resolution ---
            ddf_fullres = F.interpolate(ddf_comp, size=orig_sz,
                                        mode='trilinear', align_corners=False)
            ddf_rand_fullres = F.interpolate(ddf_rand, size=orig_sz,
                                             mode='trilinear', align_corners=False)

            # --- Apply DDFs at original resolution ---
            img_rec_fullres = apply_ddf(fullres_img_tensor, ddf_fullres,
                                        padding_mode='border')
            img_noisy_fullres = apply_ddf(fullres_img_tensor, ddf_rand_fullres,
                                          padding_mode='border')
            if fullres_msk_tensor is not None:
                # Nearest-neighbour sampling keeps mask values discrete.
                msk_rec_fullres = apply_ddf(fullres_msk_tensor, ddf_fullres,
                                            padding_mode='zeros', resample_mode='nearest')
                msk_noisy_fullres = apply_ddf(fullres_msk_tensor, ddf_rand_fullres,
                                              padding_mode='zeros', resample_mode='nearest')

            # --- Save full-res results ---
            # These are torch tensors on the inference device; _save_nifti
            # converts them to NumPy, matching every other save in this loop.
            _save_nifti(img_rec_fullres, reg_img_savepath_fullres,
                        rec_code + '.nii.gz')
            _save_nifti(img_noisy_fullres, reg_img_savepath_fullres,
                        noisy_code + '.nii.gz')
            if fullres_msk_tensor is not None:
                _save_nifti(msk_rec_fullres, reg_msk_savepath_fullres,
                            rec_code + '_GT.nii.gz')
                _save_nifti(msk_noisy_fullres, reg_msk_savepath_fullres,
                            noisy_code + '_GT.nii.gz')

            # Save full-res DDF (converet_to_nibabel handles multi-channel → channel-last)
            _save_nifti(ddf_fullres, reg_ddf_savepath_fullres,
                        rec_code + '.nii.gz')

            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # Only the first seven subjects (e = 0..6) are processed.
    # NOTE(review): nesting level of this early exit was reconstructed from
    # de-indented source; confirm it belongs to the subject loop.
    if e > 5:
        break