| """
|
| OM_reg_unpair_ext.py — Unpaired all-to-all registration using OMorpher
|
| with an external Learn2Reg-style dataset JSON (e.g. OASIS).
|
|
|
| Extracts unique subjects from the registration_val or registration_test
|
| pairs, then registers every subject to every other subject. Supports
|
| multi-class label maps (e.g. 35 brain regions) with auto-discovered
|
| label IDs. Saves registered images, masks, DDFs, and evaluation metrics
|
| (DSC, ASD, HD) per label class.
|
|
|
| Usage:
|
| python Scripts/OM_reg_unpair_ext.py -C Config/config_reg_brain.yaml \
|
| --dataset-json ~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json \
|
| --split val
|
|
|
| python Scripts/OM_reg_unpair_ext.py -C Config/config_reg_brain.yaml \
|
| --dataset-json ~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json \
|
| --split test -N 10
|
| """
|
|
|
| import os
|
| import sys
|
|
|
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
| import csv
|
| import json
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
| import nibabel as nib
|
| import yaml
|
| import SimpleITK as sitk
|
| from scipy.ndimage import distance_transform_edt, binary_erosion
|
| from tqdm import tqdm
|
|
|
| import utils
|
| from Dataloader.dataLoader import reverse_axis_order
|
| from OMorpher import OMorpher
|
|
|
|
|
|
|
| import argparse
|
|
|
# Command-line interface. Each entry is (flag names, add_argument keyword
# arguments); the options are registered in the order listed here.
_CLI_OPTIONS = (
    (
        ("--config", "-C"),
        {
            "help": "Path for the config file",
            "type": str,
            "default": "Config/config_reg_brain.yaml",
            "required": False,
        },
    ),
    (
        ("--dataset-json",),
        {
            "help": "Path to the Learn2Reg-style dataset JSON",
            "type": str,
            "default": "~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json",
        },
    ),
    (
        ("--split",),
        {
            "help": "Which registration split to use: 'val' or 'test'",
            "type": str,
            "choices": ["val", "test"],
            "default": "val",
        },
    ),
    (
        ("--max-samples", "-N"),
        {
            "help": "Max number of subjects to include (0 = all)",
            "type": int,
            "default": 0,
        },
    ),
)

parser = argparse.ArgumentParser()
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)

# Parsed once at import time; the rest of the script reads `args` directly.
args = parser.parse_args()
|
|
|
|
|
|
|
# Read the YAML experiment configuration and echo it for the run log.
with open(args.config, "r") as config_stream:
    hyp_parameters = yaml.safe_load(config_stream)
print(hyp_parameters)

# This script always processes one pair at a time, whatever the config says.
hyp_parameters["batchsize"] = 1

# Frequently used settings, pulled out of the config in a fixed order so a
# missing key is reported in the same order as before.
model_img_sz, timesteps, condition_type, ndims = (
    hyp_parameters[key]
    for key in ("img_size", "timesteps", "condition_type", "ndims")
)
|
|
|
|
|
|
|
# Locate the dataset description. Relative entries inside the JSON are later
# resolved against the directory that contains the JSON file itself.
dataset_json_path = os.path.expanduser(args.dataset_json)
dataset_root = os.path.dirname(dataset_json_path)

with open(dataset_json_path, "r") as json_stream:
    dataset_meta = json.load(json_stream)

dataset_name = dataset_meta.get("name", "UnknownDataset")
print(f"Dataset: {dataset_name}")

# Map the CLI split name onto the JSON key holding the registration pairs.
_SPLIT_TO_JSON_KEY = {"val": "registration_val", "test": "registration_test"}
if args.split not in _SPLIT_TO_JSON_KEY:
    raise ValueError(f"Unknown split: {args.split}")
pairs = dataset_meta.get(_SPLIT_TO_JSON_KEY[args.split], [])

print(f"Split: {args.split}, Pairs in JSON: {len(pairs)}")

# Collect every image referenced by a pair exactly once, in first-seen order
# (fixed before moving within each pair). The dict value is the insertion index.
_seen_paths = {}
for pair in pairs:
    for role in ("fixed", "moving"):
        _seen_paths.setdefault(pair[role], len(_seen_paths))

subject_rel_paths = list(_seen_paths)

# A positive --max-samples keeps only the first N unique subjects.
if args.max_samples > 0:
    subject_rel_paths = subject_rel_paths[: args.max_samples]

print(f"Unique subjects: {len(subject_rel_paths)} (max_samples={args.max_samples or 'all'})")

# Image basename -> label path, as declared by the dataset JSON. Training
# entries are always recorded (possibly with a None label); test entries are
# recorded only when they carry a non-empty label and may override training.
_label_lookup = {
    os.path.basename(entry["image"]): entry.get("label")
    for entry in dataset_meta.get("training", [])
}
for entry in dataset_meta.get("test", []):
    _test_key = os.path.basename(entry.get("image", ""))
    _test_label = entry.get("label")
    if _test_label:
        _label_lookup[_test_key] = _test_label
|
|
|
|
|
|
|
# Build the checkpoint location from the config:
#   Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
_model_id = hyp_parameters["model_id_str"]
_data_name = hyp_parameters["data_name"]
_net_name = hyp_parameters["net_name"]

# NOTE: despite its name, `epoch` is the checkpoint file stem, not a number.
epoch = f"{_model_id}_{_data_name}_{_net_name}"
model_save_path = os.path.join(
    f"Models/{_data_name}_{_net_name}/",
    f"{epoch}.pth",
)
print("Loading model from:", model_save_path)

# Instantiate the registration model; the device falls back to "cpu" when the
# config does not name one.
om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)
|
|
|
|
|
|
|
def _fullres_sibling(directory):
    """Return the '<directory>_fullres/' counterpart of an output directory."""
    return directory.rstrip("/") + "_fullres/"


# Model-resolution output directories come straight from the config.
reg_img_savepath = hyp_parameters["reg_img_savepath"]
reg_msk_savepath = hyp_parameters["reg_msk_savepath"]
reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"]

# Full-resolution results go to sibling directories with a "_fullres" suffix.
reg_img_savepath_fullres = _fullres_sibling(reg_img_savepath)
reg_msk_savepath_fullres = _fullres_sibling(reg_msk_savepath)
reg_ddf_savepath_fullres = _fullres_sibling(reg_ddf_savepath)

# Evaluation CSVs are written next to (one level above) the image directory.
eval_dir = os.path.join(reg_img_savepath, "..", "eval")

_output_dirs = (
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
)
for _out_dir in _output_dirs:
    os.makedirs(_out_dir, exist_ok=True)

# When True, a subject is never registered to itself.
skip_self = True
|
|
|
|
|
|
|
|
|
|
|
def resolve_path(rel_path):
    """Turn a path taken from the dataset JSON into an absolute path.

    Absolute inputs are returned untouched (not normalised). Anything else is
    joined onto ``dataset_root`` and normalised.
    """
    if not os.path.isabs(rel_path):
        joined = os.path.join(dataset_root, rel_path)
        return os.path.normpath(joined)
    return rel_path
|
|
|
|
|
def load_volume(nifti_path):
    """Read an image volume from disk and reorder its axes.

    The file is read with SimpleITK, converted to an array and passed through
    ``reverse_axis_order``. A 4-D array is reduced to 3-D by keeping index 0
    of the last axis. No intensity or size processing happens here; per the
    original note, OMorpher._standardize_img is expected to handle
    normalize -> pad-to-cube -> resize.
    """
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    return array[:, :, :, 0] if array.ndim == 4 else array
|
|
|
|
|
def load_label(nifti_path):
    """Read a label map from disk and reorder its axes.

    The file is read with SimpleITK, converted to an array and passed through
    ``reverse_axis_order``. Arrays with more than three dimensions keep only
    index 0 of the fourth axis. Per the original note,
    OMorpher._standardize_label is expected to handle
    pad-to-cube -> resize (nearest).
    """
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    if array.ndim <= 3:
        return array
    return array[:, :, :, 0]
|
|
|
|
|
def get_label_path_for_image(image_rel_path):
    """Return the absolute path of the label map for an image, or None.

    Two strategies are tried in order:

    1. Look the image basename up in ``_label_lookup`` (built from the
       dataset JSON), also trying the basename with ``.nii.gz`` swapped for
       ``.csv``. A ``.csv`` label entry is mapped back to ``.nii.gz``.
    2. Fall back to the directory convention: replace ``/images`` with
       ``/labels`` in the resolved image path.

    A candidate is only returned if it exists on disk.
    """
    img_base = os.path.basename(image_rel_path)

    # Strategy 1: explicit image -> label mapping from the dataset JSON.
    for key in (img_base, img_base.replace(".nii.gz", ".csv")):
        label_rel = _label_lookup.get(key)
        if label_rel is None:
            continue
        label_abs = resolve_path(label_rel.replace(".csv", ".nii.gz"))
        if os.path.exists(label_abs):
            return label_abs

    # Strategy 2: sibling "labels*" directory next to the "images*" one.
    img_abs = resolve_path(image_rel_path)
    label_abs = img_abs.replace("/images", "/labels")
    # str.replace returns the path unchanged when "/images" is absent. Without
    # this guard the (existing) image file would be returned as its own label.
    if label_abs != img_abs and os.path.exists(label_abs):
        return label_abs
    return None
|
|
|
|
|
def get_volume_name(path):
    """Return the file name of *path* without its NIfTI extension.

    ``.nii.gz`` is checked before ``.nii``; at most one suffix is removed and
    names with any other extension are returned unchanged.
    """
    base = os.path.basename(path)
    if base.endswith(".nii.gz"):
        return base[: -len(".nii.gz")]
    if base.endswith(".nii"):
        return base[: -len(".nii")]
    return base
|
|
|
|
|
|
|
|
|
|
|
def discover_label_ids(label_path):
    """List the label IDs present in a label file.

    The file is read with SimpleITK, its unique values are cast to ``int``,
    and the strictly positive ones are returned as a sorted list of Python
    ints (0 is excluded).
    """
    label_image = sitk.ReadImage(label_path)
    present = np.unique(sitk.GetArrayFromImage(label_image)).astype(int)
    foreground = [int(value) for value in present if value > 0]
    foreground.sort()
    return foreground
|
|
|
|
|
|
|
|
|
|
|
| def _surface_distances(pred, gt):
|
| """Compute directed surface distances between two binary masks."""
|
| pred_bool = pred > 0.5
|
| gt_bool = gt > 0.5
|
|
|
| if not np.any(pred_bool) or not np.any(gt_bool):
|
| return None, None
|
|
|
| struct = None
|
| pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
|
| gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
|
|
|
| if not np.any(pred_surface):
|
| pred_surface = pred_bool
|
| if not np.any(gt_surface):
|
| gt_surface = gt_bool
|
|
|
| dt_gt = distance_transform_edt(~gt_surface)
|
| dt_pred = distance_transform_edt(~pred_surface)
|
|
|
| return dt_gt[pred_surface], dt_pred[gt_surface]
|
|
|
|
|
def compute_dsc(pred, gt):
    """Dice similarity coefficient of two masks thresholded at 0.5.

    Returns ``2 * |P & G| / (|P| + |G|)`` as a float, or 1.0 when both masks
    are empty.
    """
    pred_mask = pred > 0.5
    gt_mask = gt > 0.5
    overlap = np.count_nonzero(pred_mask & gt_mask)
    total = np.count_nonzero(pred_mask) + np.count_nonzero(gt_mask)
    if total != 0:
        return 2.0 * float(overlap) / float(total)
    # Two empty masks agree perfectly by convention.
    return 1.0
|
|
|
|
|
def compute_asd(pred, gt):
    """Average symmetric surface distance between two masks.

    Computed as the mean of the two directed mean surface distances returned
    by ``_surface_distances``. Returns NaN when either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is not None:
        directed_means = np.mean(pred_to_gt) + np.mean(gt_to_pred)
        return directed_means / 2.0
    return float("nan")
|
|
|
|
|
def compute_hd(pred, gt):
    """Symmetric Hausdorff distance between two masks.

    Takes the larger of the two directed maxima returned by
    ``_surface_distances``. Returns NaN when either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is not None:
        worst_forward = np.max(pred_to_gt)
        worst_backward = np.max(gt_to_pred)
        return float(max(worst_forward, worst_backward))
    return float("nan")
|
|
|
|
|
|
|
|
|
# Load every subject once and keep model-resolution and full-resolution
# tensors in memory, so the all-to-all loop does no further disk reads.
subjects = []
organ_label_ids = None

for rel_path in tqdm(subject_rel_paths, desc="Loading subjects"):
    vol = load_volume(resolve_path(rel_path))

    # Let OMorpher standardize the image, then copy out both versions it keeps.
    om.set_init_img(vol)
    record = {
        "rel_path": rel_path,
        "img_model": om._init_img.clone(),
        "img_fullres": om._init_img_raw.clone(),
        "label_model": None,
        "label_fullres": None,
    }
    # Size of the full-resolution tensor, skipping its first two dimensions.
    record["orig_sz"] = list(record["img_fullres"].shape[2:])

    label_path = get_label_path_for_image(rel_path)
    if label_path is not None and os.path.exists(label_path):
        lab = load_label(label_path)
        record["label_model"], record["label_fullres"] = om._standardize_label(lab)

        # Label IDs are discovered once, from the first subject with a label.
        # NOTE(review): IDs absent from that subject are never evaluated.
        if organ_label_ids is None:
            organ_label_ids = discover_label_ids(label_path)
            print(f"Auto-discovered {len(organ_label_ids)} label classes: {organ_label_ids}")

    subjects.append(record)

print(f"Loaded {len(subjects)} subjects into memory.")

if organ_label_ids is not None:
    print(f"Organ labels for evaluation: {organ_label_ids}")
else:
    organ_label_ids = []
    print("No labels found — skipping evaluation metrics.")
|
|
|
|
|
|
|
# Short display names for the CSV headers, one per subject.
vol_names = []
for subj in subjects:
    vol_names.append(get_volume_name(subj["rel_path"]))

# Names shared by several subjects get the subject index appended, so every
# row and column label is unique.
_seen = {}
for position, name in enumerate(vol_names):
    if name not in _seen:
        _seen[name] = []
    _seen[name].append(position)
for name, positions in _seen.items():
    if len(positions) < 2:
        continue
    for position in positions:
        vol_names[position] = f"{name}_{position}"


def _empty_metric_tables():
    """Create one {metric: {(t, s): value}} table per label ID."""
    return {
        cid: {"dsc": {}, "asd": {}, "hd": {}}
        for cid in organ_label_ids
    }


# Post-registration metrics and their pre-registration baseline.
metrics = _empty_metric_tables()
metrics_pre = _empty_metric_tables()
|
|
|
|
|
|
|
def _save_nifti(tensor, directory, filename):
    """Convert *tensor* with utils.converet_to_nibabel and save it to disk."""
    nib.save(
        utils.converet_to_nibabel(tensor, ndims=ndims),
        os.path.join(directory, filename),
    )


# All-to-all registration: every subject serves as the target (condition
# image) once, and every other subject is warped onto it.
with torch.no_grad():
    for t, tgt in enumerate(tqdm(subjects, desc="Targets")):
        tgt_tag = f"Tgt{t:04d}"

        # Target reference image and label at model resolution ...
        _save_nifti(tgt["img_model"], reg_img_savepath, f"{tgt_tag}_ORG.nii.gz")
        if tgt["label_model"] is not None:
            _save_nifti(tgt["label_model"], reg_msk_savepath, f"{tgt_tag}_ORG_GT.nii.gz")

        # ... and at full resolution.
        _save_nifti(tgt["img_fullres"], reg_img_savepath_fullres, f"{tgt_tag}_ORG.nii.gz")
        if tgt["label_fullres"] is not None:
            _save_nifti(tgt["label_fullres"], reg_msk_savepath_fullres, f"{tgt_tag}_ORG_GT.nii.gz")

        for s, src in enumerate(subjects):
            if skip_self and s == t:
                continue

            pair_tag = f"{tgt_tag}_Src{s:04d}"
            pair_key = (t, s)
            print(f" Registering {pair_tag}")

            # Source is the image to deform; target is the conditioning image.
            om.set_init_img(src["img_model"])
            om.set_cond_img(tgt["img_model"].clone().detach())
            om.predict(
                T=[None, timesteps],
                proc_type=condition_type,
            )
            ddf_comp = om.get_def()

            # Warped source image, model resolution.
            img_rec = om.apply_def(
                img=src["img_model"], ddf=ddf_comp, padding_mode="zeros",
            )
            _save_nifti(img_rec, reg_img_savepath, f"{pair_tag}.nii.gz")

            # Warped source label, model resolution (nearest-neighbour).
            label_rec = None
            if src["label_model"] is not None:
                label_rec = om.apply_def(
                    img=src["label_model"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                _save_nifti(label_rec, reg_msk_savepath, f"{pair_tag}_GT.nii.gz")

            # Displacement field, model resolution.
            _save_nifti(ddf_comp, reg_ddf_savepath, f"{pair_tag}.nii.gz")

            # Warped source image, full resolution.
            # NOTE(review): uses "border" padding while the model-resolution
            # image uses "zeros" — confirm this difference is intentional.
            img_rec_fullres = om.apply_def(
                img=src["img_fullres"], ddf=ddf_comp, padding_mode="border",
            )
            _save_nifti(img_rec_fullres, reg_img_savepath_fullres, f"{pair_tag}.nii.gz")

            # Warped source label, full resolution.
            label_rec_fullres = None
            if src["label_fullres"] is not None:
                label_rec_fullres = om.apply_def(
                    img=src["label_fullres"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                _save_nifti(label_rec_fullres, reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz")

            # Displacement field resampled to the source's full-resolution size.
            # NOTE(review): "trilinear" assumes volumetric (3-D) data.
            ddf_fullres = F.interpolate(
                ddf_comp, size=src["orig_sz"], mode="trilinear", align_corners=False,
            )
            _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz")

            # Evaluation needs label IDs plus labels on both sides of the pair.
            can_evaluate = (
                organ_label_ids
                and label_rec_fullres is not None
                and tgt["label_fullres"] is not None
            )
            if not can_evaluate:
                continue

            # Round once so each per-class comparison works on integer IDs.
            tgt_ids = np.round(tgt["label_fullres"][0, 0].cpu().numpy())
            src_ids = np.round(src["label_fullres"][0, 0].cpu().numpy())
            reg_ids = np.round(label_rec_fullres[0, 0].cpu().numpy())

            for cid in organ_label_ids:
                tgt_mask = (tgt_ids == cid).astype(np.float32)
                src_mask = (src_ids == cid).astype(np.float32)
                reg_mask = (reg_ids == cid).astype(np.float32)

                # A class absent from both target and source is skipped.
                if np.sum(tgt_mask) == 0 and np.sum(src_mask) == 0:
                    continue

                # Baseline (unregistered source) first, then registered result.
                for table, moving_mask in ((metrics_pre, src_mask), (metrics, reg_mask)):
                    dsc_value = compute_dsc(moving_mask, tgt_mask)
                    asd_value = compute_asd(moving_mask, tgt_mask)
                    hd_value = compute_hd(moving_mask, tgt_mask)
                    table[cid]["dsc"][pair_key] = dsc_value
                    table[cid]["asd"][pair_key] = asd_value
                    table[cid]["hd"][pair_key] = hd_value

            # One-line progress summary over the classes evaluated for this pair.
            post_dscs = [
                metrics[cid]["dsc"][pair_key]
                for cid in organ_label_ids
                if pair_key in metrics[cid]["dsc"]
            ]
            if post_dscs:
                pre_dscs = [
                    metrics_pre[cid]["dsc"].get(pair_key, float("nan"))
                    for cid in organ_label_ids
                    if pair_key in metrics_pre[cid]["dsc"]
                ]
                print(
                    f" Mean DSC: pre={np.mean(pre_dscs):.4f} "
                    f"post={np.mean(post_dscs):.4f}"
                )

print("\nAll-to-all unpaired registration complete.")

# Number of subjects; sets the size of the target-by-source CSV matrices.
n_subj = len(subjects)
|
|
|
|
|
| def _fmt(val):
|
| if val is None:
|
| return ""
|
| if np.isnan(val):
|
| return "NaN"
|
| return f"{val:.6f}"
|
|
|
|
|
|
|
# Per-label target-by-source matrices, one CSV per (label, metric).
# Post-registration files ("labelXX_<metric>.csv") are written first, then
# the pre-registration baseline ("labelXX_pre_<metric>.csv").
for _table, _infix in ((metrics, ""), (metrics_pre, "pre_")):
    for cid in organ_label_ids:
        prefix = f"label{cid:02d}_{_infix}"

        for metric_name in ("dsc", "asd", "hd"):
            csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
            pair_values = _table[cid][metric_name]
            with open(csv_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["target \\ source"] + vol_names)
                for t_idx in range(n_subj):
                    # Pairs without a recorded value produce an empty cell.
                    cells = [
                        _fmt(pair_values.get((t_idx, s_idx)))
                        for s_idx in range(n_subj)
                    ]
                    writer.writerow([vol_names[t_idx]] + cells)
            print(f"Saved {csv_path}")
|
|
|
|
|
def _finite(values):
    """Keep only the non-NaN entries of *values*."""
    return [v for v in values if not np.isnan(v)]


def _mean_and_std(values):
    """Return (mean, std) of *values*, or (NaN, NaN) for an empty list."""
    if not values:
        return float("nan"), float("nan")
    return np.mean(values), np.std(values)


# Per-label summary: mean and standard deviation over all pairs, before and
# after registration, ignoring NaN entries.
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([
        "label_id", "metric",
        "pre_mean", "pre_std",
        "post_mean", "post_std",
        "n_pairs",
    ])
    for cid in organ_label_ids:
        for metric_name in ("dsc", "asd", "hd"):
            pre_vals = _finite(metrics_pre[cid][metric_name].values())
            post_vals = _finite(metrics[cid][metric_name].values())
            pre_mean, pre_std = _mean_and_std(pre_vals)
            post_mean, post_std = _mean_and_std(post_vals)
            # n_pairs is the larger of the two valid-pair counts.
            n = max(len(pre_vals), len(post_vals))
            writer.writerow([
                cid,
                metric_name.upper(),
                _fmt(pre_mean), _fmt(pre_std),
                _fmt(post_mean), _fmt(post_std),
                n,
            ])
print(f"Saved {overall_path}")
|
|
|
|
|
# Grand summary: for each metric, the mean and standard deviation of the
# per-class means (each label class counts equally), before and after
# registration. Classes with no valid value are left out of the statistics,
# but n_classes always reports the total number of label IDs.
grand_path = os.path.join(eval_dir, "grand_summary.csv")
with open(grand_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["metric", "pre_mean", "pre_std", "post_mean", "post_std", "n_classes"])
    for metric_name in ("dsc", "asd", "hd"):
        class_means = {"pre": [], "post": []}
        for cid in organ_label_ids:
            for _phase, _table in (("pre", metrics_pre), ("post", metrics)):
                valid = [
                    v for v in _table[cid][metric_name].values()
                    if not np.isnan(v)
                ]
                if valid:
                    class_means[_phase].append(np.mean(valid))

        row = [metric_name.upper()]
        for _phase in ("pre", "post"):
            means = class_means[_phase]
            row.append(_fmt(np.mean(means) if means else float("nan")))
            row.append(_fmt(np.std(means) if means else float("nan")))
        row.append(len(organ_label_ids))
        writer.writerow(row)
print(f"Saved {grand_path}")
|
|
|