# Omini3D / Scripts / OM_reg_unpair.py
# (Repository page header: maxmo2009 — "Sync from local: code + epoch-110
#  checkpoint, clean README" — commit 2af0e94, verified.)
"""
OM_reg_unpair.py — Unpaired all-to-all registration using OMorpher.
Registers every subject to every other subject in a nested loop.
Output naming uses Tgt/Src prefixes instead of the Patient/Slice barcode.
Computes DSC, ASD, HD for organ labels (excludes tumour/lesion labels)
and saves per-pair tables and summary statistics as CSVs.
Usage:
python Scripts/OM_reg_unpair.py -C Config/config_om.yaml
"""
import os
import sys
# Add project root to path so imports work from Scripts/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import csv
import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
import yaml
import SimpleITK as sitk
from scipy.ndimage import distance_transform_edt, binary_erosion
from tqdm import tqdm
import utils
from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order
from OMorpher import OMorpher
# ========== CLI ==========
# Command-line options, declared as (flags, add_argument kwargs) pairs.
import argparse

_CLI_OPTIONS = (
    (
        ("--config", "-C"),
        dict(
            help="Path for the config file",
            type=str,
            default="Config/config_om.yaml",
            required=False,
        ),
    ),
    (
        ("--max-samples", "-N"),
        dict(
            help="Max number of subjects to include (0 = all)",
            type=int,
            default=0,
        ),
    ),
)

parser = argparse.ArgumentParser()
for _flags, _opts in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_opts)
args = parser.parse_args()
# ========== Config ==========
# Load the YAML hyper-parameters (safe_load: plain YAML types only) and echo them.
with open(args.config, "r") as cfg_file:
    hyp_parameters = yaml.safe_load(cfg_file)
print(hyp_parameters)

# Registration is run one pair at a time.
hyp_parameters["batchsize"] = 1

model_img_sz, timesteps, condition_type, ndims = (
    hyp_parameters[cfg_key]
    for cfg_key in ("img_size", "timesteps", "condition_type", "ndims")
)
# ========== Dataset ==========
# Label keys to load per subject, and the database(s) to draw subjects from.
label_keys = ["brain"]
database = ["Brats2019"]

dataset = OminiDataset_inference_w_all(
    database=database,
    label_key=label_keys,
    min_crop_ratio=1.0,
    transform=None,
)
# ========== OMorpher setup ==========
# Checkpoint file name: "<model_id_str>_<data_name>_<net_name>.pth", stored in
# "Models/<data_name>_<net_name>/".
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
_model_dir = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(_model_dir, f"{epoch}.pth")
print("Loading model from:", model_save_path)

_device = str(hyp_parameters.get("device", "cpu"))
om = OMorpher(config=hyp_parameters, checkpoint_path=model_save_path, device=_device)
print(om)
# ========== Output directories ==========
reg_img_savepath = hyp_parameters["reg_img_savepath"]
reg_msk_savepath = hyp_parameters["reg_msk_savepath"]
reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"]

# Full-resolution outputs go to sibling folders with a "_fullres" suffix.
reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres = (
    base.rstrip("/") + "_fullres/"
    for base in (reg_img_savepath, reg_msk_savepath, reg_ddf_savepath)
)

# Metric CSVs are written next to (not inside) the image output folder.
eval_dir = os.path.join(reg_img_savepath, "..", "eval")

_all_output_dirs = (
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
)
for _out_dir in _all_output_dirs:
    os.makedirs(_out_dir, exist_ok=True)
# ========== Settings ==========
# Do not register a subject to itself (pairs with source == target).
skip_self = True

# Label names (compared lower-cased) that are NOT organs and are therefore
# excluded from DSC/ASD/HD evaluation. The BraTS "brain" label is a tumour
# segmentation (non-enhancing tumour, edema, enhancing tumour) rather than a
# whole-brain mask, which is why it is listed here. With label_keys = ["brain"]
# no organ metrics are produced; only %|J|<0 is reported.
EXCLUDE_LABELS = {
    "brain",    # BraTS tumour segmentation
    "tumor",    # PSMA-CT / PSMA-FDG tumour
    "tumour",
    "noisy",    # Kaggle OSIC artefact mask
}

# A label whose lower-cased name contains any of these fragments is excluded too.
EXCLUDE_SUBSTRINGS = {"tumor", "tumour", "lesion"}
# ========== Helper functions ==========
def center_pad_to_cube(volume):
    """Zero-pad the first three axes of ``volume`` to the length of the longest one.

    Padding is split as evenly as possible between both ends of each axis; an
    odd remainder goes to the end. Any trailing axes beyond the third are left
    unpadded.

    Args:
        volume: numpy array with at least three axes.

    Returns:
        The padded array; its first three axes all have length
        ``max(volume.shape[:3])``.
    """
    spatial = volume.shape[:3]
    side = max(spatial)
    widths = []
    for length in spatial:
        missing = side - length
        head = missing // 2
        widths.append((head, missing - head))
    # Non-spatial (e.g. channel) axes are never padded.
    widths.extend([(0, 0)] * (volume.ndim - 3))
    return np.pad(volume, widths, mode="constant", constant_values=0)
def load_fullres_volume(key, ds):
    """Read an image at its original resolution and prepare it like the dataset does.

    Steps: read with SimpleITK, convert to a numpy array, apply
    ``reverse_axis_order``, pick one channel of 4-D images, clip CT volumes to
    ``ds.clamp_range`` (when set), normalize with ``ds.normalize``, and finally
    centre-pad the first three axes to a cube.

    Args:
        key: image file path, also used as the key into ``ds.ALLdata_filtered``.
        ds: dataset object providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.

    Returns:
        numpy array of the processed volume.
    """
    arr = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    if arr.ndim == 4:
        # Use the first configured channel, or channel 0 if none is configured.
        ids = ds.get_channel_ids(key)
        arr = arr[:, :, :, ids[0] if len(ids) > 0 else 0]

    # Intensity clipping only applies to CT images.
    if ds.clamp_range is not None and ds.ALLdata_filtered[key].get("Modality", None) == "CT":
        arr = np.clip(arr, ds.clamp_range[0], ds.clamp_range[1])

    return center_pad_to_cube(ds.normalize(arr))
def load_fullres_label(key, ds, label_key):
    """Read one segmentation label at original resolution, axis-reordered and cube-padded.

    The label path is looked up in
    ``ds.ALLdata_filtered[key]["Label_path"]["segmentation"][label_key]``.

    Args:
        key: image path / key into ``ds.ALLdata_filtered``.
        ds: dataset object providing ``ALLdata_filtered`` and ``get_channel_ids``.
        label_key: name of the segmentation label to load.

    Returns:
        numpy array of the padded label, or ``None`` if the subject has no such label.
    """
    seg_paths = ds.ALLdata_filtered[key].get("Label_path", {}).get("segmentation", {})
    if label_key not in seg_paths:
        return None

    lab = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(seg_paths[label_key])))

    if lab.ndim > 3:
        ids = ds.get_channel_ids(key)
        # NOTE(review): indexing with the id collection (not a single id) keeps a
        # trailing channel axis, unlike load_fullres_volume — confirm intended.
        if len(ids) != 0:
            lab = lab[..., ids]

    return center_pad_to_cube(lab)
def get_volume_name(key):
    """Return the file's basename with a trailing ``.nii.gz`` or ``.nii`` removed.

    Any other extension is kept as-is.
    """
    base = os.path.basename(key)
    for suffix in (".nii.gz", ".nii"):
        if base.endswith(suffix):
            return base[: -len(suffix)]
    return base
def is_organ_label(label_key):
    """Return True unless ``label_key`` names a tumour/lesion/artefact label.

    The name is compared case-insensitively against ``EXCLUDE_LABELS`` (exact
    match) and ``EXCLUDE_SUBSTRINGS`` (substring match).
    """
    name = label_key.lower()
    if name in EXCLUDE_LABELS:
        return False
    for fragment in EXCLUDE_SUBSTRINGS:
        if fragment in name:
            return False
    return True
# ---------- Evaluation metrics ----------
def _surface_distances(pred, gt):
"""Compute directed surface distances between two binary masks.
Returns (dist_pred_to_gt, dist_gt_to_pred) arrays, or (None, None)
if either mask is empty or has no extractable surface.
"""
pred_bool = pred > 0.5
gt_bool = gt > 0.5
if not np.any(pred_bool) or not np.any(gt_bool):
return None, None
# Extract surface voxels via erosion
struct = None # default 3x3(x3) cross connectivity
pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
# Fallback: single-voxel regions lose their surface after erosion
if not np.any(pred_surface):
pred_surface = pred_bool
if not np.any(gt_surface):
gt_surface = gt_bool
dt_gt = distance_transform_edt(~gt_surface)
dt_pred = distance_transform_edt(~pred_surface)
return dt_gt[pred_surface], dt_pred[gt_surface]
def compute_dsc(pred, gt):
    """Dice Similarity Coefficient of two masks binarized at > 0.5.

    Returns 2|A∩B| / (|A| + |B|) as a Python float. Returns 1.0 when both
    masks are empty, treating that as perfect agreement.
    """
    p = pred > 0.5
    g = gt > 0.5
    total = np.sum(p) + np.sum(g)
    if total == 0:
        return 1.0
    overlap = np.sum(np.logical_and(p, g))
    return 2.0 * float(overlap) / float(total)
def compute_asd(pred, gt):
    """Average symmetric surface distance (in the units of ``_surface_distances``).

    Returns the mean of the two directed mean distances
    (pred->gt and gt->pred), not the mean over the pooled surface voxels.
    Returns NaN if either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is None:
        return float("nan")
    mean_forward = np.mean(pred_to_gt)
    mean_backward = np.mean(gt_to_pred)
    return (mean_forward + mean_backward) / 2.0
def compute_hd(pred, gt):
    """Hausdorff distance: the larger of the two directed maximum surface distances.

    Returns a Python float, or NaN if either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is None:
        return float("nan")
    return float(max(pred_to_gt.max(), gt_to_pred.max()))
def compute_negdetj_pct(ddf, ndims=3):
    """Return the percentage of voxels whose Jacobian determinant is negative.

    The Jacobian of x -> x + u(x) is approximated as I + grad(u). Entry (i, j)
    is the forward difference of channel i along array axis j. The last slice
    along each axis is repeated before differencing, so the final difference
    along every axis is zero.

    NOTE(review): this treats channel i as the displacement along array axis i,
    in voxel units. If OMorpher's DDF uses another channel order or normalized
    coordinates, the sign test may be inaccurate — confirm.

    Args:
        ddf: displacement field as a torch tensor or numpy array, shaped either
            [B, ndims, *spatial] (only the first batch item is used) or
            [ndims, *spatial].
        ndims: spatial dimensionality, 2 or 3.

    Returns:
        float: 100 * (#voxels with det(J) < 0) / (#voxels).

    Raises:
        ValueError: if ``ndims`` is neither 2 nor 3.
    """
    if isinstance(ddf, torch.Tensor):
        ddf = ddf.detach().cpu().numpy()
    if ddf.ndim == ndims + 2:
        ddf = ddf[0]  # drop the batch axis -> [C, *spatial]

    def _grad(comp, axis):
        # Forward difference of channel `comp` along `axis`; appending the last
        # slice keeps the output the same shape as the input.
        field = ddf[comp]
        return np.diff(field, axis=axis, append=np.take(field, [-1], axis=axis))

    if ndims == 3:
        g = [[_grad(i, j) for j in range(3)] for i in range(3)]
        a11, a12, a13 = 1.0 + g[0][0], g[0][1], g[0][2]
        a21, a22, a23 = g[1][0], 1.0 + g[1][1], g[1][2]
        a31, a32, a33 = g[2][0], g[2][1], 1.0 + g[2][2]
        # Cofactor expansion along the first row.
        detj = (
            a11 * (a22 * a33 - a23 * a32)
            - a12 * (a21 * a33 - a23 * a31)
            + a13 * (a21 * a32 - a22 * a31)
        )
    elif ndims == 2:
        g = [[_grad(i, j) for j in range(2)] for i in range(2)]
        detj = (1.0 + g[0][0]) * (1.0 + g[1][1]) - g[0][1] * g[1][0]
    else:
        raise ValueError(f"Unsupported ndims={ndims}")

    return 100.0 * float(np.count_nonzero(detj < 0)) / float(detj.size)
# ========== Pre-load all subjects ==========
# Every subject's image and masks are prepared once up front and kept in memory,
# so the all-to-all loop below only does registration and I/O.
keys = list(dataset.ALLdata_filtered.keys())
if args.max_samples > 0:
    keys = keys[: args.max_samples]
print(f"Total subjects: {len(keys)} (max_samples={args.max_samples or 'all'})")

subjects = []
for key in tqdm(keys, desc="Loading subjects"):
    # set_init_img standardizes the volume; the two copies read back are
    # presumably the model-resolution (_init_img) and full-resolution
    # (_init_img_raw) tensors — confirm against OMorpher.
    om.set_init_img(load_fullres_volume(key, dataset))
    img_model = om._init_img.clone()
    img_fullres = om._init_img_raw.clone()

    # One (model-res, full-res) label pair per label key. load_fullres_label
    # may return None; how _standardize_label handles that is defined in
    # OMorpher (the metric loop later skips masks filled with -1).
    label_pairs = [
        om._standardize_label(load_fullres_label(key, dataset, lk))
        for lk in label_keys
    ]
    if label_pairs:
        mask_model = torch.cat([pair[0] for pair in label_pairs], dim=1)
        mask_fullres = torch.cat([pair[1] for pair in label_pairs], dim=1)
    else:
        mask_model = mask_fullres = None

    subjects.append(
        dict(
            key=key,
            img_model=img_model,
            img_fullres=img_fullres,
            mask_model=mask_model,
            mask_fullres=mask_fullres,
            orig_sz=list(img_fullres.shape[2:]),
        )
    )
print(f"Loaded {len(subjects)} subjects into memory.")
# ========== Prepare evaluation structures ==========
# Column/row names for the per-pair CSV matrices, one per subject (same order).
vol_names = [get_volume_name(subj["key"]) for subj in subjects]

# Disambiguate duplicate basenames by appending the subject index. A suffixed
# name can itself clash with another subject's basename (e.g. "a", "a", "a_1"
# would otherwise yield two "a_1" columns), so extend the suffix until unique.
_seen = {}
for i, vn in enumerate(vol_names):
    _seen.setdefault(vn, []).append(i)
_taken = {vn for vn, indices in _seen.items() if len(indices) == 1}
for vn, indices in _seen.items():
    if len(indices) > 1:
        for idx in indices:
            candidate = f"{vn}_{idx}"
            extra = 1
            while candidate in _taken:
                candidate = f"{vn}_{idx}_{extra}"
                extra += 1
            _taken.add(candidate)
            vol_names[idx] = candidate

# (channel_index, label_key) for organ labels only; the channel index matches
# the concatenation order of label_keys in the stacked mask tensors.
organ_label_indices = [
    (c, lk) for c, lk in enumerate(label_keys) if is_organ_label(lk)
]
if organ_label_indices:
    print(f"Organ labels for evaluation: {[lk for _, lk in organ_label_indices]}")
else:
    print("No organ labels found — skipping evaluation metrics.")

# metrics[label_key][metric_name][(t, s)] = value
metrics = {
    lk: {"dsc": {}, "asd": {}, "hd": {}}
    for _, lk in organ_label_indices
}
# Per-pair DDF quality metric
negdetj_pct = {}  # (t, s) -> percentage of negative Jacobian determinant
# ========== All-to-all registration ==========
# Interpolation mode used to resize the DDF to full resolution. F.interpolate
# accepts "trilinear" only for 5-D input [B, C, D, H, W]; a 2-D run
# (ndims == 2) yields a 4-D DDF, which needs "bilinear" or interpolate raises.
ddf_resize_mode = "trilinear" if ndims == 3 else "bilinear"

with torch.no_grad():
    for t, tgt in enumerate(tqdm(subjects, desc="Targets")):
        tgt_tag = f"Tgt{t:04d}"

        # --- Save target original at model resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_model"], ndims=ndims),
            os.path.join(reg_img_savepath, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["mask_model"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["mask_model"], ndims=ndims),
                os.path.join(reg_msk_savepath, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Save target original at full resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_fullres"], ndims=ndims),
            os.path.join(reg_img_savepath_fullres, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["mask_fullres"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["mask_fullres"], ndims=ndims),
                os.path.join(reg_msk_savepath_fullres, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Inner loop: register each source to this target ---
        for s, src in enumerate(subjects):
            if skip_self and s == t:
                continue
            pair_tag = f"Tgt{t:04d}_Src{s:04d}"
            print(f" Registering {pair_tag}")

            # The source image is set as the initial image and the target as
            # the condition; the predicted deformation warps source -> target.
            om.set_init_img(src["img_model"])
            om.set_cond_img(tgt["img_model"].clone().detach())
            om.predict(
                T=[None, timesteps],
                proc_type=condition_type,
            )
            ddf_comp = om.get_def()

            # --- DDF quality: percent negative Jacobian determinant ---
            neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
            negdetj_pct[(t, s)] = neg_pct
            print(f" %|J|<0 = {neg_pct:.4f}%")

            # --- Model-resolution registered image ---
            img_rec = om.apply_def(
                img=src["img_model"], ddf=ddf_comp, padding_mode="zeros",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec, ndims=ndims),
                os.path.join(reg_img_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Model-resolution registered mask (nearest keeps labels discrete) ---
            msk_rec = None
            if src["mask_model"] is not None:
                msk_rec = om.apply_def(
                    img=src["mask_model"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(msk_rec, ndims=ndims),
                    os.path.join(reg_msk_savepath, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Model-resolution DDF ---
            nib.save(
                utils.converet_to_nibabel(ddf_comp, ndims=ndims),
                os.path.join(reg_ddf_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered image ---
            # NOTE(review): "border" padding here vs "zeros" at model
            # resolution — confirm the difference is intended.
            img_rec_fullres = om.apply_def(
                img=src["img_fullres"], ddf=ddf_comp, padding_mode="border",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec_fullres, ndims=ndims),
                os.path.join(reg_img_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered mask ---
            msk_rec_fullres = None
            if src["mask_fullres"] is not None:
                msk_rec_fullres = om.apply_def(
                    img=src["mask_fullres"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(msk_rec_fullres, ndims=ndims),
                    os.path.join(reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Full-resolution DDF ---
            # NOTE(review): this resizes the field but does not rescale the
            # displacement values; that is only correct if the DDF is in
            # resolution-independent (normalized) units — confirm with OMorpher.
            ddf_fullres = F.interpolate(
                ddf_comp, size=src["orig_sz"], mode=ddf_resize_mode, align_corners=False,
            )
            nib.save(
                utils.converet_to_nibabel(ddf_fullres, ndims=ndims),
                os.path.join(reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Evaluation metrics (full-res organ labels) ---
            if (
                organ_label_indices
                and msk_rec_fullres is not None
                and tgt["mask_fullres"] is not None
            ):
                reg_spatial = tuple(msk_rec_fullres.shape[2:])
                tgt_spatial = tuple(tgt["mask_fullres"].shape[2:])
                if reg_spatial != tgt_spatial:
                    # The warped mask is on the grid apply_def returns for the
                    # source's full-res mask. If that differs from the target's
                    # full-res grid, voxel-wise metrics are undefined (numpy
                    # would raise or mis-broadcast), so skip rather than abort
                    # the whole run.
                    print(
                        f" [WARN] {pair_tag}: warped mask shape {reg_spatial} "
                        f"!= target mask shape {tgt_spatial}; skipping metrics"
                    )
                else:
                    for c, lk in organ_label_indices:
                        tgt_mask_np = tgt["mask_fullres"][0, c].cpu().numpy()
                        reg_mask_np = msk_rec_fullres[0, c].cpu().numpy()
                        # Skip placeholder masks (fill_value = -1)
                        if np.all(tgt_mask_np < 0) or np.all(reg_mask_np < 0):
                            continue
                        dsc_val = compute_dsc(reg_mask_np, tgt_mask_np)
                        asd_val = compute_asd(reg_mask_np, tgt_mask_np)
                        hd_val = compute_hd(reg_mask_np, tgt_mask_np)
                        metrics[lk]["dsc"][(t, s)] = dsc_val
                        metrics[lk]["asd"][(t, s)] = asd_val
                        metrics[lk]["hd"][(t, s)] = hd_val
                        print(
                            f" [{lk}] DSC={dsc_val:.4f} "
                            f"ASD={asd_val:.2f} HD={hd_val:.2f}"
                        )

print("All-to-all registration complete.")
# ========== Write evaluation CSVs ==========
n_subj = len(subjects)


def _write_pair_matrix(path, lookup, fmt_cell):
    """Write an n_subj x n_subj matrix CSV (rows = targets, columns = sources).

    ``lookup`` maps (target_idx, source_idx) -> value. Pairs with no entry
    (self-pairs or skipped pairs) become empty cells; present values are
    rendered with ``fmt_cell``.
    """
    with open(path, "w", newline="") as fh:
        out = csv.writer(fh)
        # Header row: corner label + source volume names
        out.writerow(["target \\ source"] + vol_names)
        for row_idx in range(n_subj):
            cells = [vol_names[row_idx]]
            for col_idx in range(n_subj):
                value = lookup.get((row_idx, col_idx))
                cells.append("" if value is None else fmt_cell(value))
            out.writerow(cells)


def _fmt_nan_aware(value):
    """Six-decimal string, or "NaN" for NaN values."""
    return "NaN" if np.isnan(value) else f"{value:.6f}"


# --- %|J|<0 matrix CSV ---
negdetj_csv_path = os.path.join(eval_dir, "negdetj_pct.csv")
_write_pair_matrix(negdetj_csv_path, negdetj_pct, lambda value: f"{value:.6f}")
print(f"Saved {negdetj_csv_path}")

# --- Per-label metric matrices; the label prefix is only needed with >1 organ label ---
_label_prefix_needed = len(organ_label_indices) > 1
for _, lk in organ_label_indices:
    prefix = f"{lk}_" if _label_prefix_needed else ""
    for metric_name in ("dsc", "asd", "hd"):
        csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
        _write_pair_matrix(csv_path, metrics[lk][metric_name], _fmt_nan_aware)
        print(f"Saved {csv_path}")
# --- Overall summary ---
# One row per (label, metric) with mean / population std (np.std, ddof=0) over
# the non-NaN per-pair values, plus a pooled %|J|<0 row.
overall_path = os.path.join(eval_dir, "overall.csv")


def _summary_fmt(value):
    """Six-decimal string, or "NaN" for NaN values."""
    return "NaN" if np.isnan(value) else f"{value:.6f}"


with open(overall_path, "w", newline="") as summary_fh:
    summary_writer = csv.writer(summary_fh)
    summary_writer.writerow(["label", "metric", "mean", "std", "n_pairs"])

    # %|J|<0 is a property of the DDF, so it is pooled over all pairs (not per label)
    # and only written when at least one pair was registered.
    jac_vals = [v for v in negdetj_pct.values() if not np.isnan(v)]
    if jac_vals:
        summary_writer.writerow(
            [
                "ALL",
                "%|J|<0",
                f"{np.mean(jac_vals):.6f}",
                f"{np.std(jac_vals):.6f}",
                len(jac_vals),
            ]
        )

    # Organ-label rows are always written; an empty value set yields NaN / 0.
    for _, lk in organ_label_indices:
        for metric_name in ("dsc", "asd", "hd"):
            kept = [v for v in metrics[lk][metric_name].values() if not np.isnan(v)]
            mean_val = np.mean(kept) if kept else float("nan")
            std_val = np.std(kept) if kept else float("nan")
            summary_writer.writerow(
                [
                    lk,
                    metric_name.upper(),
                    _summary_fmt(mean_val),
                    _summary_fmt(std_val),
                    len(kept),
                ]
            )
print(f"Saved {overall_path}")