| """
|
| OM_reg_unpair.py — Unpaired all-to-all registration using OMorpher.
|
|
|
| Registers every subject to every other subject in a nested loop.
|
| Output naming uses Tgt/Src prefixes instead of the Patient/Slice barcode.
|
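Computes DSC, ASD and HD (surface distances in voxel units) for organ labels.
Labels named brain/noisy, or whose name contains tumor/tumour/lesion, are
excluded. Also computes the percentage of voxels with a negative Jacobian
determinant. Saves per-pair target-by-source matrices and summary statistics
as CSVs.

Usage:

python Scripts/OM_reg_unpair.py -C Config/config_om.yaml [-N MAX_SUBJECTS]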
|
| """
|
|
|
| import os
|
| import sys
|
|
|
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
| import csv
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
| import nibabel as nib
|
| import yaml
|
| import SimpleITK as sitk
|
| from scipy.ndimage import distance_transform_edt, binary_erosion
|
| from tqdm import tqdm
|
|
|
| import utils
|
| from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order
|
| from OMorpher import OMorpher
|
|
|
|
|
|
|
| import argparse
|
|
|
# Command-line interface: a config path and an optional cap on subject count.
parser = argparse.ArgumentParser()

# -C / --config: YAML file with model, data and output-path settings.
parser.add_argument(
    "--config", "-C",
    type=str,
    required=False,
    default="Config/config_om.yaml",
    help="Path for the config file",
)

# -N / --max-samples: keep only the first N subjects; 0 keeps all of them.
parser.add_argument(
    "--max-samples", "-N",
    type=int,
    default=0,
    help="Max number of subjects to include (0 = all)",
)

args = parser.parse_args()
|
|
|
|
|
|
|
# Load the YAML configuration and echo it for the run log.
with open(args.config) as cfg_file:
    hyp_parameters = yaml.safe_load(cfg_file)
print(hyp_parameters)

# Registration is done one pair at a time, so force a batch size of 1.
hyp_parameters["batchsize"] = 1

# Frequently used settings, pulled out of the config (KeyError if missing).
model_img_sz, timesteps, condition_type, ndims = (
    hyp_parameters[cfg_key]
    for cfg_key in ("img_size", "timesteps", "condition_type", "ndims")
)
|
|
|
|
|
|
|
# Label names to load for every subject, and the source database(s).
# NOTE(review): "brain" is also a member of EXCLUDE_LABELS (defined further
# down), so is_organ_label("brain") is False and DSC/ASD/HD are never computed
# with this configuration. Confirm that is intended, e.g. if the BraTS "brain"
# label is really a tumour segmentation.
label_keys = ["brain"]
database = ["Brats2019"]

# Inference dataset; transform=None means no augmentation is passed in.
# min_crop_ratio=1.0 presumably disables cropping. TODO: confirm in
# OminiDataset_inference_w_all.
dataset = OminiDataset_inference_w_all(
    transform=None, min_crop_ratio=1.0, label_key=label_keys, database=database,
)
|
|
|
|
|
|
|
# Checkpoint stem "<model_id_str>_<data_name>_<net_name>". Despite its name,
# `epoch` holds this identifier string, not an epoch number.
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
# Relative path: Models/<data_name>_<net_name>/<epoch>.pth
model_save_path = os.path.join(
    f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
    str(epoch) + ".pth",
)
print("Loading model from:", model_save_path)

# Registration wrapper built from the config and the checkpoint. The device
# comes from the config's "device" entry and falls back to "cpu".
om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)
|
|
|
|
|
|
|
# Model-resolution output directories, taken straight from the config.
reg_img_savepath, reg_msk_savepath, reg_ddf_savepath = (
    hyp_parameters[cfg_key]
    for cfg_key in ("reg_img_savepath", "reg_msk_savepath", "reg_ddf_savepath")
)

# Full-resolution siblings: "<dir>" -> "<dir>_fullres/" (trailing "/" dropped first).
reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres = (
    base_dir.rstrip("/") + "_fullres/"
    for base_dir in (reg_img_savepath, reg_msk_savepath, reg_ddf_savepath)
)

# Evaluation CSVs go to an "eval" folder next to the image output directory.
eval_dir = os.path.join(reg_img_savepath, "..", "eval")

_output_dirs = (
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
)
for out_dir in _output_dirs:
    os.makedirs(out_dir, exist_ok=True)
|
|
|
|
|
|
|
# Do not register a subject onto itself (the diagonal of the all-to-all grid).
skip_self = True

# Lower-cased label names that are never scored as organs.
# NOTE(review): "brain" is listed here, and the script's label_keys is
# ["brain"], so no label is evaluated. Confirm this is intended.
EXCLUDE_LABELS = {"brain", "tumor", "tumour", "noisy"}

# A label whose lower-cased name contains any of these substrings is also
# excluded from organ metrics (tumour/lesion labels).
EXCLUDE_SUBSTRINGS = {"lesion", "tumor", "tumour"}
|
|
|
|
|
|
|
|
|
|
|
def center_pad_to_cube(volume):
    """Zero-pad the first three axes of *volume* up to a common edge length.

    The edge length is the largest of the first three dimensions. Each shorter
    axis is padded symmetrically; when the padding is odd, the extra voxel goes
    after the data. Any axes beyond the third are left untouched.
    """
    spatial = volume.shape[:3]
    edge = max(spatial)
    widths = [((edge - n) // 2, (edge - n) - (edge - n) // 2) for n in spatial]
    widths.extend([(0, 0)] * (volume.ndim - 3))
    return np.pad(volume, widths, mode="constant", constant_values=0)
|
|
|
|
|
def load_fullres_volume(key, ds):
    """Read the image at *key* at native resolution and prepare it.

    Steps, in order:
      1. read with SimpleITK and reorder axes via ``reverse_axis_order``;
      2. for 4-D arrays keep one channel: the first id returned by
         ``ds.get_channel_ids(key)``, or channel 0 if there is none;
      3. clip to ``ds.clamp_range`` when one is set and the entry's
         "Modality" is "CT";
      4. apply ``ds.normalize`` and centre-pad to a cube.
    """
    arr = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    if arr.ndim == 4:
        ids = ds.get_channel_ids(key)
        arr = arr[:, :, :, ids[0] if len(ids) > 0 else 0]

    clamp = ds.clamp_range
    # Short-circuit keeps the metadata lookup conditional on a clamp range.
    if clamp is not None and ds.ALLdata_filtered[key].get("Modality") == "CT":
        arr = np.clip(arr, clamp[0], clamp[1])

    return center_pad_to_cube(ds.normalize(arr))
|
|
|
|
|
def load_fullres_label(key, ds, label_key):
    """Read segmentation *label_key* of entry *key* at native resolution.

    Returns None when the entry has no segmentation under that name.
    Otherwise the array is axis-reordered with ``reverse_axis_order``. When it
    has more than three dims and ``ds.get_channel_ids(key)`` is non-empty, the
    last axis is indexed with those ids. The result is centre-padded to a cube.
    """
    seg_paths = (
        ds.ALLdata_filtered[key].get("Label_path", {}).get("segmentation", {})
    )
    if label_key not in seg_paths:
        return None

    lab = reverse_axis_order(
        sitk.GetArrayFromImage(sitk.ReadImage(seg_paths[label_key]))
    )
    if lab.ndim > 3:
        ids = ds.get_channel_ids(key)
        if len(ids) != 0:
            lab = lab[..., ids]
    return center_pad_to_cube(lab)
|
|
|
|
|
def get_volume_name(key):
    """Return the file name of *key* without its directory and NIfTI extension.

    ".nii.gz" is checked before ".nii"; other extensions are kept as-is.
    """
    base = os.path.basename(key)
    suffix = next((ext for ext in (".nii.gz", ".nii") if base.endswith(ext)), None)
    return base[: -len(suffix)] if suffix else base
|
|
|
|
|
def is_organ_label(label_key):
    """Return True if *label_key* should be scored as an organ.

    The check is case-insensitive. A label is rejected when its name is in
    EXCLUDE_LABELS, or when it contains any entry of EXCLUDE_SUBSTRINGS
    (tumour/lesion-style labels).
    """
    name = label_key.lower()
    excluded = name in EXCLUDE_LABELS or any(
        part in name for part in EXCLUDE_SUBSTRINGS
    )
    return not excluded
|
|
|
|
|
|
|
|
|
|
|
| def _surface_distances(pred, gt):
|
| """Compute directed surface distances between two binary masks.
|
|
|
| Returns (dist_pred_to_gt, dist_gt_to_pred) arrays, or (None, None)
|
| if either mask is empty or has no extractable surface.
|
| """
|
| pred_bool = pred > 0.5
|
| gt_bool = gt > 0.5
|
|
|
| if not np.any(pred_bool) or not np.any(gt_bool):
|
| return None, None
|
|
|
|
|
| struct = None
|
| pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
|
| gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
|
|
|
|
|
| if not np.any(pred_surface):
|
| pred_surface = pred_bool
|
| if not np.any(gt_surface):
|
| gt_surface = gt_bool
|
|
|
| dt_gt = distance_transform_edt(~gt_surface)
|
| dt_pred = distance_transform_edt(~pred_surface)
|
|
|
| return dt_gt[pred_surface], dt_pred[gt_surface]
|
|
|
|
|
def compute_dsc(pred, gt):
    """Dice similarity coefficient of two masks thresholded at 0.5.

    Returns 2 * |P & G| / (|P| + |G|) as a float, or 1.0 when both masks are
    empty (perfect agreement on "nothing").
    """
    p, g = pred > 0.5, gt > 0.5
    overlap = int(np.sum(p & g))
    total = int(np.sum(p)) + int(np.sum(g))
    return 1.0 if total == 0 else 2.0 * overlap / total
|
|
|
|
|
def compute_asd(pred, gt):
    """Average symmetric surface distance, in voxel units.

    Computed as (mean pred->gt distance + mean gt->pred distance) / 2.
    Returns NaN when either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is None:
        return float("nan")
    return (np.mean(pred_to_gt) + np.mean(gt_to_pred)) / 2.0
|
|
|
|
|
def compute_hd(pred, gt):
    """Symmetric Hausdorff distance (max of both directed maxima), in voxels.

    Returns NaN when either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is None:
        return float("nan")
    return float(max(np.max(pred_to_gt), np.max(gt_to_pred)))
|
|
|
|
|
def compute_negdetj_pct(ddf, ndims=3):
    """Return the percentage of voxels whose Jacobian determinant is negative.

    The Jacobian of x -> x + u(x) is I + grad(u). grad(u) uses forward
    differences in which the last slice along each axis is repeated, so its
    derivative is 0 there.

    Args:
        ddf: displacement field as a torch tensor or numpy array, shaped
            ``[1, ndims, *spatial]`` or ``[ndims, *spatial]``. Channel ``i`` is
            differentiated along spatial axis ``j`` to give J[i][j].
            NOTE(review): the channel-to-axis correspondence and the units of
            the displacement (voxels vs normalised coordinates) are assumed,
            not checked.
        ndims: number of spatial dimensions, 2 or 3.

    Returns:
        100 * (#voxels with det(J) < 0) / (#voxels), as a float.

    Raises:
        ValueError: if ``ndims`` is neither 2 nor 3.
    """
    if isinstance(ddf, torch.Tensor):
        ddf = ddf.detach().cpu().numpy()
    if ddf.ndim == ndims + 2:
        ddf = ddf[0]  # drop the leading batch axis

    def _jacobian(field, size):
        # jac[i][j] = d(u_i)/d(x_j), plus 1 on the diagonal (identity term).
        jac = [
            [
                np.diff(field[i], axis=j, append=np.take(field[i], [-1], axis=j))
                for j in range(size)
            ]
            for i in range(size)
        ]
        for i in range(size):
            jac[i][i] = 1.0 + jac[i][i]
        return jac

    if ndims == 3:
        (j11, j12, j13), (j21, j22, j23), (j31, j32, j33) = _jacobian(ddf, 3)
        # Cofactor expansion along the first row.
        detj = (
            j11 * (j22 * j33 - j23 * j32)
            - j12 * (j21 * j33 - j23 * j31)
            + j13 * (j21 * j32 - j22 * j31)
        )
    elif ndims == 2:
        (j11, j12), (j21, j22) = _jacobian(ddf, 2)
        detj = j11 * j22 - j12 * j21
    else:
        raise ValueError(f"Unsupported ndims={ndims}")

    return 100.0 * float(np.sum(detj < 0)) / float(detj.size)
|
|
|
|
|
|
|
|
|
# Subject keys in dataset order, optionally truncated by --max-samples.
keys = list(dataset.ALLdata_filtered.keys())
if args.max_samples > 0:
    keys = keys[: args.max_samples]
print(f"Total subjects: {len(keys)} (max_samples={args.max_samples or 'all'})")

# Pre-load every subject (image and labels, at model and full resolution) so
# the all-to-all loop can reuse them without re-reading files. Memory use grows
# linearly with the number of subjects.
subjects = []
for key in tqdm(keys, desc="Loading subjects"):
    fullres_vol = load_fullres_volume(key, dataset)
    # set_init_img is used here to run OMorpher's preprocessing. The resulting
    # tensors are cloned out, presumably because later set_init_img calls
    # overwrite these private attributes.
    om.set_init_img(fullres_vol)
    img_model = om._init_img.clone()
    img_fullres = om._init_img_raw.clone()
    # Spatial size of the full-res tensor. Assumes a [B, C, *spatial] layout
    # (TODO: confirm); it is later used as the F.interpolate target size.
    orig_sz = list(img_fullres.shape[2:])

    masks_model = []
    masks_fullres = []
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        # NOTE(review): lab is None when the subject lacks this label. Confirm
        # that _standardize_label handles None, e.g. by returning a
        # placeholder, because channels are concatenated below and later
        # indexed by position.
        model_t, fullres_t = om._standardize_label(lab)
        masks_model.append(model_t)
        masks_fullres.append(fullres_t)

    # One channel per entry of label_keys; None only when label_keys is empty.
    if masks_model:
        mask_model = torch.cat(masks_model, dim=1)
        mask_fullres = torch.cat(masks_fullres, dim=1)
    else:
        mask_model = None
        mask_fullres = None

    subjects.append({
        "key": key,
        "img_model": img_model,
        "img_fullres": img_fullres,
        "mask_model": mask_model,
        "mask_fullres": mask_fullres,
        "orig_sz": orig_sz,
    })

print(f"Loaded {len(subjects)} subjects into memory.")
|
|
|
|
|
|
|
# Row/column labels for the CSV matrices, one per subject, in subject order.
vol_names = [get_volume_name(entry["key"]) for entry in subjects]

# Group subject indices by base name. Every member of a clashing group is then
# suffixed with its own index so the labels become unique.
_seen = {}
for idx, base in enumerate(vol_names):
    _seen.setdefault(base, []).append(idx)
vol_names = [
    f"{base}_{idx}" if len(_seen[base]) > 1 else base
    for idx, base in enumerate(vol_names)
]

# (channel index, label name) for every label that gets DSC/ASD/HD scores.
organ_label_indices = [
    (c, lk) for c, lk in enumerate(label_keys) if is_organ_label(lk)
]

if not organ_label_indices:
    print("No organ labels found — skipping evaluation metrics.")
else:
    print(f"Organ labels for evaluation: {[name for _, name in organ_label_indices]}")

# metrics[label][metric][(target_idx, source_idx)] -> value
metrics = {
    name: {metric: {} for metric in ("dsc", "asd", "hd")}
    for _, name in organ_label_indices
}

# negdetj_pct[(target_idx, source_idx)] -> % of voxels with det(J) < 0
negdetj_pct = dict()
|
|
|
|
|
|
|
# F.interpolate's "trilinear" mode only accepts 5-D input (3 spatial dims).
# A 2-D displacement field is a 4-D tensor and needs "bilinear"; hard-coding
# "trilinear" made ndims == 2 runs crash at the full-res DDF resize.
ddf_interp_mode = "trilinear" if ndims == 3 else "bilinear"

# All-to-all registration: each subject in turn is the target (conditioning
# image), and every other subject is registered onto it as the source.
with torch.no_grad():
    for t, tgt in enumerate(tqdm(subjects, desc="Targets")):
        tgt_tag = f"Tgt{t:04d}"

        # Save the unregistered target image (and mask, if any) once per
        # target, first at model resolution ...
        nib.save(
            utils.converet_to_nibabel(tgt["img_model"], ndims=ndims),
            os.path.join(reg_img_savepath, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["mask_model"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["mask_model"], ndims=ndims),
                os.path.join(reg_msk_savepath, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # ... then at full resolution.
        nib.save(
            utils.converet_to_nibabel(tgt["img_fullres"], ndims=ndims),
            os.path.join(reg_img_savepath_fullres, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["mask_fullres"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["mask_fullres"], ndims=ndims),
                os.path.join(reg_msk_savepath_fullres, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        for s, src in enumerate(subjects):
            if skip_self and s == t:
                continue

            pair_tag = f"Tgt{t:04d}_Src{s:04d}"
            print(f" Registering {pair_tag}")

            # Source image is the initial image; the target is the condition.
            om.set_init_img(src["img_model"])
            om.set_cond_img(tgt["img_model"].clone().detach())

            om.predict(
                T=[None, timesteps],
                proc_type=condition_type,
            )

            ddf_comp = om.get_def()

            # Folding measure of the predicted deformation.
            neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
            negdetj_pct[(t, s)] = neg_pct
            print(f" %|J|<0 = {neg_pct:.4f}%")

            # --- Model-resolution outputs: warped image, warped mask, DDF ---
            img_rec = om.apply_def(
                img=src["img_model"], ddf=ddf_comp, padding_mode="zeros",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec, ndims=ndims),
                os.path.join(reg_img_savepath, f"{pair_tag}.nii.gz"),
            )

            msk_rec = None
            if src["mask_model"] is not None:
                # Nearest-neighbour resampling keeps label values discrete.
                msk_rec = om.apply_def(
                    img=src["mask_model"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(msk_rec, ndims=ndims),
                    os.path.join(reg_msk_savepath, f"{pair_tag}_GT.nii.gz"),
                )

            nib.save(
                utils.converet_to_nibabel(ddf_comp, ndims=ndims),
                os.path.join(reg_ddf_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution outputs: same DDF applied to full-res data ---
            img_rec_fullres = om.apply_def(
                img=src["img_fullres"], ddf=ddf_comp, padding_mode="border",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec_fullres, ndims=ndims),
                os.path.join(reg_img_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            msk_rec_fullres = None
            if src["mask_fullres"] is not None:
                msk_rec_fullres = om.apply_def(
                    img=src["mask_fullres"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(msk_rec_fullres, ndims=ndims),
                    os.path.join(reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz"),
                )

            # NOTE(review): the DDF is resized to the *source* full-res size,
            # and its vector values are not rescaled. That is only correct if
            # the DDF is in normalised coordinates and the source and target
            # full-res grids coincide. Confirm both.
            ddf_fullres = F.interpolate(
                ddf_comp, size=src["orig_sz"], mode=ddf_interp_mode,
                align_corners=False,
            )
            nib.save(
                utils.converet_to_nibabel(ddf_fullres, ndims=ndims),
                os.path.join(reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # Overlap/surface metrics on full-resolution masks, organ labels only.
            if (
                organ_label_indices
                and msk_rec_fullres is not None
                and tgt["mask_fullres"] is not None
            ):
                for c, lk in organ_label_indices:
                    tgt_mask_np = tgt["mask_fullres"][0, c].cpu().numpy()
                    reg_mask_np = msk_rec_fullres[0, c].cpu().numpy()

                    # Skip any channel that is entirely negative; presumably
                    # a placeholder for a label this subject lacks (TODO: confirm).
                    if np.all(tgt_mask_np < 0) or np.all(reg_mask_np < 0):
                        continue

                    dsc_val = compute_dsc(reg_mask_np, tgt_mask_np)
                    asd_val = compute_asd(reg_mask_np, tgt_mask_np)
                    hd_val = compute_hd(reg_mask_np, tgt_mask_np)

                    metrics[lk]["dsc"][(t, s)] = dsc_val
                    metrics[lk]["asd"][(t, s)] = asd_val
                    metrics[lk]["hd"][(t, s)] = hd_val

                    print(
                        f" [{lk}] DSC={dsc_val:.4f} "
                        f"ASD={asd_val:.2f} HD={hd_val:.2f}"
                    )

print("All-to-all registration complete.")
|
|
|
|
|
|
|
n_subj = len(subjects)


def _format_cell(val):
    """Format one CSV value: "" if missing (None), "NaN" if NaN, else 6 decimals."""
    if val is None:
        return ""
    if np.isnan(val):
        return "NaN"
    return f"{val:.6f}"


def _write_pair_matrix_csv(csv_path, values, names):
    """Write a target-by-source matrix CSV and report where it was saved.

    Args:
        csv_path: output file path.
        values: dict mapping (target_idx, source_idx) -> float. Pairs that
            were not registered or evaluated (e.g. self-pairs) are absent and
            produce empty cells.
        names: row/column labels, one per subject, in subject order.
    """
    with open(csv_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["target \\ source"] + names)
        for t_idx, t_name in enumerate(names):
            writer.writerow(
                [t_name]
                + [_format_cell(values.get((t_idx, s_idx))) for s_idx in range(len(names))]
            )
    print(f"Saved {csv_path}")


# Folding matrix. NaN cells are now written as "NaN", like the metric matrices
# (previously the negdetj matrix wrote "nan").
_write_pair_matrix_csv(os.path.join(eval_dir, "negdetj_pct.csv"), negdetj_pct, vol_names)

# One DSC/ASD/HD matrix per organ label. File names carry a label prefix only
# when more than one organ label is evaluated.
for _, lk in organ_label_indices:
    prefix = f"{lk}_" if len(organ_label_indices) > 1 else ""
    for metric_name in ["dsc", "asd", "hd"]:
        _write_pair_matrix_csv(
            os.path.join(eval_dir, f"{prefix}{metric_name}.csv"),
            metrics[lk][metric_name],
            vol_names,
        )

# Summary table: mean and population std (np.std, ddof=0) over non-NaN pairs.
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["label", "metric", "mean", "std", "n_pairs"])

    negdetj_vals = [v for v in negdetj_pct.values() if not np.isnan(v)]
    if negdetj_vals:
        writer.writerow([
            "ALL", "%|J|<0",
            f"{np.mean(negdetj_vals):.6f}", f"{np.std(negdetj_vals):.6f}",
            len(negdetj_vals),
        ])

    for _, lk in organ_label_indices:
        for metric_name in ["dsc", "asd", "hd"]:
            vals = [
                v for v in metrics[lk][metric_name].values()
                if not np.isnan(v)
            ]
            if vals:
                mean_val = np.mean(vals)
                std_val = np.std(vals)
                n_pairs = len(vals)
            else:
                mean_val = float("nan")
                std_val = float("nan")
                n_pairs = 0
            writer.writerow([
                lk,
                metric_name.upper(),
                _format_cell(mean_val),
                _format_cell(std_val),
                n_pairs,
            ])
print(f"Saved {overall_path}")
|
|
|