# Omini3D / Scripts / OM_reg_unpair_ext.py
# maxmo2009's picture
# Sync from local: code + epoch-110 checkpoint, clean README
# 2af0e94 verified
"""
OM_reg_unpair_ext.py — Unpaired all-to-all registration using OMorpher
with an external Learn2Reg-style dataset JSON (e.g. OASIS).
Extracts unique subjects from the registration_val or registration_test
pairs, then registers every subject to every other subject. Supports
multi-class label maps (e.g. 35 brain regions) with auto-discovered
label IDs. Saves registered images, masks, DDFs, and evaluation metrics
(DSC, ASD, HD) per label class.

Usage:
    python Scripts/OM_reg_unpair_ext.py -C Config/config_reg_brain.yaml \
        --dataset-json ~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json \
        --split val

    python Scripts/OM_reg_unpair_ext.py -C Config/config_reg_brain.yaml \
        --dataset-json ~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json \
        --split test -N 10
"""
import os
import sys
# Add project root to path so imports work from Scripts/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import csv
import json
import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
import yaml
import SimpleITK as sitk
from scipy.ndimage import distance_transform_edt, binary_erosion
from tqdm import tqdm
import utils
from Dataloader.dataLoader import reverse_axis_order
from OMorpher import OMorpher
# ========== CLI ==========
import argparse

# Every option carries a default, so the script can be launched without any
# arguments. Options are registered in the order they appear in --help.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", "-C",
    type=str,
    default="Config/config_reg_brain.yaml",
    help="Path for the config file",
)
parser.add_argument(
    "--dataset-json",
    type=str,
    default="~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json",
    help="Path to the Learn2Reg-style dataset JSON",
)
parser.add_argument(
    "--split",
    type=str,
    choices=["val", "test"],
    default="val",
    help="Which registration split to use: 'val' or 'test'",
)
parser.add_argument(
    "--max-samples", "-N",
    type=int,
    default=0,
    help="Max number of subjects to include (0 = all)",
)
args = parser.parse_args()
# ========== Config ==========
# Read the YAML configuration named on the command line and echo it.
with open(args.config, "r") as cfg_stream:
    hyp_parameters = yaml.safe_load(cfg_stream)
print(hyp_parameters)

# This script always runs with a batch size of one, whatever the config says.
hyp_parameters["batchsize"] = 1

# Pull out the settings referenced later in the script (same lookup order as
# a sequence of plain assignments, so a missing key fails on the same name).
model_img_sz, timesteps, condition_type, ndims = (
    hyp_parameters[key]
    for key in ("img_size", "timesteps", "condition_type", "ndims")
)
# ========== Load external dataset JSON ==========
dataset_json_path = os.path.expanduser(args.dataset_json)
# Relative paths inside the JSON are resolved against the JSON's own folder.
dataset_root = os.path.dirname(dataset_json_path)
with open(dataset_json_path, "r") as json_stream:
    dataset_meta = json.load(json_stream)
dataset_name = dataset_meta.get("name", "UnknownDataset")
print(f"Dataset: {dataset_name}")

# Select registration split
_split_to_key = {"val": "registration_val", "test": "registration_test"}
if args.split not in _split_to_key:
    raise ValueError(f"Unknown split: {args.split}")
pairs = dataset_meta.get(_split_to_key[args.split], [])
print(f"Split: {args.split}, Pairs in JSON: {len(pairs)}")

# Extract unique subject image paths from the pairs, keeping first-seen order
# (fixed before moving within each pair). Values are first-seen indices.
_seen_paths = {}
for pair in pairs:
    for role in ("fixed", "moving"):
        _seen_paths.setdefault(pair[role], len(_seen_paths))
subject_rel_paths = list(_seen_paths)
if args.max_samples > 0:
    subject_rel_paths = subject_rel_paths[: args.max_samples]
print(f"Unique subjects: {len(subject_rel_paths)} (max_samples={args.max_samples or 'all'})")

# Build label lookup: image basename -> label relative path.
# Training entries are always recorded (the label may be None); test entries
# are recorded only when they carry a non-empty label.
_label_lookup = {
    os.path.basename(entry["image"]): entry.get("label")
    for entry in dataset_meta.get("training", [])
}
for entry in dataset_meta.get("test", []):
    img_base = os.path.basename(entry.get("image", ""))
    test_label = entry.get("label")
    if test_label:
        _label_lookup[img_base] = test_label
# ========== OMorpher setup ==========
# The checkpoint file name is "<model_id_str>_<data_name>_<net_name>.pth",
# stored under "Models/<data_name>_<net_name>/".
_model_id = hyp_parameters["model_id_str"]
_data_name = hyp_parameters["data_name"]
_net_name = hyp_parameters["net_name"]
epoch = f"{_model_id}_{_data_name}_{_net_name}"
model_save_path = os.path.join(f"Models/{_data_name}_{_net_name}/", epoch + ".pth")
print("Loading model from:", model_save_path)

# Device falls back to "cpu" when the config has no "device" entry.
om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)
# ========== Output directories ==========
# Model-resolution outputs go to the three configured folders.
reg_img_savepath = hyp_parameters["reg_img_savepath"]
reg_msk_savepath = hyp_parameters["reg_msk_savepath"]
reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"]

# Full-resolution outputs go to sibling folders carrying a "_fullres" suffix.
reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres = (
    base.rstrip("/") + "_fullres/"
    for base in (reg_img_savepath, reg_msk_savepath, reg_ddf_savepath)
)

# Evaluation CSVs live in an "eval" folder next to the image output folder.
eval_dir = os.path.join(reg_img_savepath, "..", "eval")

for out_dir in (
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
):
    os.makedirs(out_dir, exist_ok=True)
# ========== Settings ==========
# When True, the all-to-all loop skips every pair whose source index equals
# the target index, so a subject is never registered to itself.
skip_self = True # skip pairs where source == target
# ========== Helper functions ==========
def resolve_path(rel_path):
    """Turn a path taken from the dataset JSON into a usable file path.

    Absolute paths are returned untouched; anything else is joined onto the
    directory that contains the dataset JSON and normalised.
    """
    already_absolute = os.path.isabs(rel_path)
    return (
        rel_path
        if already_absolute
        else os.path.normpath(os.path.join(dataset_root, rel_path))
    )
def load_volume(nifti_path):
    """Read an image file into an array, applying only the axis reordering.

    The file is read with SimpleITK, converted to an array and passed through
    ``reverse_axis_order``. A 4-D result is reduced to 3-D by keeping index 0
    of its last axis. Per the original note, intensity normalisation, padding
    to a cube and resizing are left to ``OMorpher._standardize_img``
    (not verifiable from this file).
    """
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    return array[:, :, :, 0] if array.ndim == 4 else array
def load_label(nifti_path):
    """Read a label map file into an array, applying only the axis reordering.

    The file is read with SimpleITK, converted to an array and passed through
    ``reverse_axis_order``. Any result with more than three dimensions is
    indexed with ``[:, :, :, 0]``. Per the original note, padding to a cube
    and nearest-neighbour resizing are left to
    ``OMorpher._standardize_label`` (not verifiable from this file).
    """
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    if array.ndim <= 3:
        return array
    return array[:, :, :, 0]
def get_label_path_for_image(image_rel_path):
    """Locate the label map file that belongs to an image.

    Lookup order:
      1. ``_label_lookup`` (built from the JSON "training"/"test" entries),
         keyed first by the image basename and then by that basename with
         ".nii.gz" swapped for ".csv". A ".csv" in the recorded label path is
         rewritten to ".nii.gz" before checking that the file exists.
      2. Fallback: substitute "/images" with "/labels" in the resolved image
         path and use the result if it exists.

    Args:
        image_rel_path: Image path as written in the dataset JSON.

    Returns:
        Absolute path of an existing label file, or None when none is found.
    """
    img_base = os.path.basename(image_rel_path)
    # Fix extension mismatch: JSON test entries may use .csv but files are .nii.gz
    for key in (img_base, img_base.replace(".nii.gz", ".csv")):
        label_rel = _label_lookup.get(key)
        if label_rel is None:
            continue
        # Ensure we use .nii.gz extension for the actual file
        label_abs = resolve_path(label_rel.replace(".csv", ".nii.gz"))
        if os.path.exists(label_abs):
            return label_abs

    # Fallback: derive label path from image path (images* -> labels*)
    img_abs = resolve_path(image_rel_path)
    label_abs = img_abs.replace("/images", "/labels")
    # If the path contains no "/images" the substitution is a no-op; without
    # this guard the image itself would be returned as its own label map.
    if label_abs != img_abs and os.path.exists(label_abs):
        return label_abs
    return None
def get_volume_name(path):
    """Return the file's basename with a trailing ".nii.gz" or ".nii" removed.

    ".nii.gz" is tried before ".nii"; names with neither suffix come back
    unchanged.
    """
    base = os.path.basename(path)
    suffix = next((ext for ext in (".nii.gz", ".nii") if base.endswith(ext)), "")
    return base[: len(base) - len(suffix)]
# ---------- Auto-discover label IDs ----------
def discover_label_ids(label_path):
    """List the positive label values present in a label file, ascending.

    The file is read with SimpleITK; its unique values are cast to int and
    anything not greater than zero (i.e. background) is discarded.
    """
    voxels = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    found = np.unique(voxels).astype(int)
    positive_ids = [int(value) for value in found if value > 0]
    positive_ids.sort()
    return positive_ids
# ---------- Evaluation metrics ----------
def _surface_distances(pred, gt):
"""Compute directed surface distances between two binary masks."""
pred_bool = pred > 0.5
gt_bool = gt > 0.5
if not np.any(pred_bool) or not np.any(gt_bool):
return None, None
struct = None
pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
if not np.any(pred_surface):
pred_surface = pred_bool
if not np.any(gt_surface):
gt_surface = gt_bool
dt_gt = distance_transform_edt(~gt_surface)
dt_pred = distance_transform_edt(~pred_surface)
return dt_gt[pred_surface], dt_pred[gt_surface]
def compute_dsc(pred, gt):
    """Dice Similarity Coefficient of two masks thresholded at 0.5.

    Returns 1.0 when both masks are empty, otherwise
    ``2 * |pred AND gt| / (|pred| + |gt|)`` as a Python float.
    """
    fg_pred = pred > 0.5
    fg_gt = gt > 0.5
    total = np.sum(fg_pred) + np.sum(fg_gt)
    if total == 0:
        # Two empty masks agree perfectly.
        return 1.0
    overlap = np.sum(np.logical_and(fg_pred, fg_gt))
    return 2.0 * float(overlap) / float(total)
def compute_asd(pred, gt):
    """Average (symmetric) Surface Distance.

    Computed as the mean of the two directed mean surface distances from
    ``_surface_distances``. NaN when either mask has no foreground.
    """
    forward, backward = _surface_distances(pred, gt)
    if forward is None:
        return float("nan")
    mean_forward = np.mean(forward)
    mean_backward = np.mean(backward)
    return (mean_forward + mean_backward) / 2.0
def compute_hd(pred, gt):
    """Hausdorff Distance: the larger of the two directed maxima.

    Uses the directed surface distances from ``_surface_distances``.
    NaN when either mask has no foreground.
    """
    forward, backward = _surface_distances(pred, gt)
    if forward is None:
        return float("nan")
    worst_forward = np.max(forward)
    worst_backward = np.max(backward)
    return float(max(worst_forward, worst_backward))
# ========== Pre-load all subjects ==========
# Every subject is loaded once up front; the all-to-all loop then reuses the
# cached tensors instead of re-reading files for each pair.
subjects = []
organ_label_ids = None  # auto-discovered from first label file
for rel_path in tqdm(subject_rel_paths, desc="Loading subjects"):
    # Hand the raw volume to OMorpher, then copy out both tensors it stores
    # (_init_img and _init_img_raw) before the next subject overwrites them.
    om.set_init_img(load_volume(resolve_path(rel_path)))
    img_model = om._init_img.clone()
    img_fullres = om._init_img_raw.clone()
    orig_sz = list(img_fullres.shape[2:])

    # Load label (single-channel multi-class map)
    label_model = label_fullres = None
    label_path = get_label_path_for_image(rel_path)
    if label_path is not None and os.path.exists(label_path):
        label_model, label_fullres = om._standardize_label(load_label(label_path))
        # Auto-discover label IDs from the first available label
        if organ_label_ids is None:
            organ_label_ids = discover_label_ids(label_path)
            print(f"Auto-discovered {len(organ_label_ids)} label classes: {organ_label_ids}")

    subjects.append(
        dict(
            rel_path=rel_path,
            img_model=img_model,
            img_fullres=img_fullres,
            label_model=label_model,
            label_fullres=label_fullres,
            orig_sz=orig_sz,
        )
    )
print(f"Loaded {len(subjects)} subjects into memory.")

if organ_label_ids is not None:
    print(f"Organ labels for evaluation: {organ_label_ids}")
else:
    # No subject had a label file: metrics are disabled via an empty ID list.
    organ_label_ids = []
    print("No labels found — skipping evaluation metrics.")
# ========== Prepare evaluation structures ==========
_base_names = [get_volume_name(entry["rel_path"]) for entry in subjects]

# Disambiguate duplicate basenames by appending index: group subject
# positions by basename, then suffix every name that occurs more than once.
_seen = {}
for position, base in enumerate(_base_names):
    _seen.setdefault(base, []).append(position)
vol_names = [
    f"{base}_{position}" if len(_seen[base]) > 1 else base
    for position, base in enumerate(_base_names)
]


def _empty_metric_store():
    """Build {label_id: {"dsc": {}, "asd": {}, "hd": {}}} for all label IDs."""
    return {
        cid: {metric_key: {} for metric_key in ("dsc", "asd", "hd")}
        for cid in organ_label_ids
    }


# metrics[label_id][metric_name][(t, s)] = value (post-registration)
metrics = _empty_metric_store()
# metrics_pre: same structure for pre-registration
metrics_pre = _empty_metric_store()
# ========== All-to-all registration ==========
def _asd_hd_single_pass(pred_mask, gt_mask):
    """Return (ASD, HD) for a mask pair from one surface-distance computation.

    Produces the same values as calling compute_asd and compute_hd separately,
    but runs the erosion and distance transforms once instead of twice.
    Both values are NaN when either mask has no foreground.
    """
    d_fwd, d_bwd = _surface_distances(pred_mask, gt_mask)
    if d_fwd is None:
        return float("nan"), float("nan")
    asd = (np.mean(d_fwd) + np.mean(d_bwd)) / 2.0
    hd = float(max(np.max(d_fwd), np.max(d_bwd)))
    return asd, hd


# For every target, register every other subject (the source) onto it, save
# the warped image / label / DDF at model and full resolution, and record
# per-class DSC / ASD / HD before and after registration.
with torch.no_grad():
    for t, tgt in enumerate(tqdm(subjects, desc="Targets")):
        tgt_tag = f"Tgt{t:04d}"

        # --- Save target original at model resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_model"], ndims=ndims),
            os.path.join(reg_img_savepath, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["label_model"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["label_model"], ndims=ndims),
                os.path.join(reg_msk_savepath, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Save target original at full resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_fullres"], ndims=ndims),
            os.path.join(reg_img_savepath_fullres, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["label_fullres"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["label_fullres"], ndims=ndims),
                os.path.join(reg_msk_savepath_fullres, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Inner loop: register each source to this target ---
        for s, src in enumerate(subjects):
            if skip_self and s == t:
                continue
            pair_tag = f"Tgt{t:04d}_Src{s:04d}"
            print(f" Registering {pair_tag}")

            # Source is the image to be deformed, target is the condition.
            om.set_init_img(src["img_model"])
            om.set_cond_img(tgt["img_model"].clone().detach())
            om.predict(
                T=[None, timesteps],
                proc_type=condition_type,
            )
            ddf_comp = om.get_def()

            # --- Model-resolution registered image ---
            img_rec = om.apply_def(
                img=src["img_model"], ddf=ddf_comp, padding_mode="zeros",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec, ndims=ndims),
                os.path.join(reg_img_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Model-resolution registered label ---
            label_rec = None
            if src["label_model"] is not None:
                # Nearest-neighbour resampling keeps label values discrete.
                label_rec = om.apply_def(
                    img=src["label_model"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(label_rec, ndims=ndims),
                    os.path.join(reg_msk_savepath, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Model-resolution DDF ---
            nib.save(
                utils.converet_to_nibabel(ddf_comp, ndims=ndims),
                os.path.join(reg_ddf_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered image ---
            img_rec_fullres = om.apply_def(
                img=src["img_fullres"], ddf=ddf_comp, padding_mode="border",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec_fullres, ndims=ndims),
                os.path.join(reg_img_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered label ---
            label_rec_fullres = None
            if src["label_fullres"] is not None:
                label_rec_fullres = om.apply_def(
                    img=src["label_fullres"], ddf=ddf_comp,
                    padding_mode="zeros", resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(label_rec_fullres, ndims=ndims),
                    os.path.join(reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Full-resolution DDF ---
            # NOTE(review): only the grid is resampled to the source's original
            # size; displacement values are not rescaled here. Confirm this
            # matches the units of the field returned by om.get_def().
            ddf_fullres = F.interpolate(
                ddf_comp, size=src["orig_sz"], mode="trilinear", align_corners=False,
            )
            nib.save(
                utils.converet_to_nibabel(ddf_fullres, ndims=ndims),
                os.path.join(reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Evaluation metrics (per-class from multi-class label) ---
            # NOTE(review): the mask comparisons assume the source, target and
            # warped full-res label arrays share one shape. Confirm for
            # datasets with varying volume sizes.
            if (
                organ_label_ids
                and label_rec_fullres is not None
                and tgt["label_fullres"] is not None
            ):
                # Round each label volume once per pair, not once per class.
                tgt_label_np = np.round(tgt["label_fullres"][0, 0].cpu().numpy())
                src_label_np = np.round(src["label_fullres"][0, 0].cpu().numpy())
                reg_label_np = np.round(label_rec_fullres[0, 0].cpu().numpy())

                for cid in organ_label_ids:
                    tgt_mask = (tgt_label_np == cid).astype(np.float32)
                    src_mask = (src_label_np == cid).astype(np.float32)
                    reg_mask = (reg_label_np == cid).astype(np.float32)

                    # Skip if both masks are empty
                    if np.sum(tgt_mask) == 0 and np.sum(src_mask) == 0:
                        continue

                    # Pre-registration: source vs target
                    pre_asd, pre_hd = _asd_hd_single_pass(src_mask, tgt_mask)
                    metrics_pre[cid]["dsc"][(t, s)] = compute_dsc(src_mask, tgt_mask)
                    metrics_pre[cid]["asd"][(t, s)] = pre_asd
                    metrics_pre[cid]["hd"][(t, s)] = pre_hd

                    # Post-registration: registered label vs target
                    post_asd, post_hd = _asd_hd_single_pass(reg_mask, tgt_mask)
                    metrics[cid]["dsc"][(t, s)] = compute_dsc(reg_mask, tgt_mask)
                    metrics[cid]["asd"][(t, s)] = post_asd
                    metrics[cid]["hd"][(t, s)] = post_hd

                # Print summary for this pair (mean across classes). Pre and
                # post entries are always written together, so both lists
                # cover the same classes.
                pre_dscs = [
                    metrics_pre[cid]["dsc"][(t, s)]
                    for cid in organ_label_ids
                    if (t, s) in metrics_pre[cid]["dsc"]
                ]
                post_dscs = [
                    metrics[cid]["dsc"][(t, s)]
                    for cid in organ_label_ids
                    if (t, s) in metrics[cid]["dsc"]
                ]
                if post_dscs:
                    print(
                        f" Mean DSC: pre={np.mean(pre_dscs):.4f} "
                        f"post={np.mean(post_dscs):.4f}"
                    )

print("\nAll-to-all unpaired registration complete.")
# ========== Write evaluation CSVs ==========
# Number of loaded subjects; it sets both the row (target) and column (source)
# count of every per-class metric matrix written below.
n_subj = len(subjects)
def _fmt(val):
if val is None:
return ""
if np.isnan(val):
return "NaN"
return f"{val:.6f}"
# --- Per-class matrix CSVs (post-registration, then pre-registration) ---
def _write_metric_matrices(store, tag):
    """Write one target-by-source CSV per (label class, metric) in *store*.

    Files are named "label<cid>_<tag><metric>.csv" inside eval_dir. Rows are
    targets and columns are sources; pairs without a recorded value are
    written as empty cells.
    """
    for cid in organ_label_ids:
        prefix = f"label{cid:02d}_{tag}"
        for metric_name in ("dsc", "asd", "hd"):
            csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
            table = store[cid][metric_name]
            with open(csv_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["target \\ source"] + vol_names)
                for t_idx in range(n_subj):
                    cells = [_fmt(table.get((t_idx, s_idx))) for s_idx in range(n_subj)]
                    writer.writerow([vol_names[t_idx]] + cells)
            print(f"Saved {csv_path}")


_write_metric_matrices(metrics, "")
_write_metric_matrices(metrics_pre, "pre_")
# --- Overall summary ---
# One row per (label class, metric): mean / std over all pairs, before and
# after registration, ignoring NaN entries.
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([
        "label_id", "metric",
        "pre_mean", "pre_std",
        "post_mean", "post_std",
        "n_pairs",
    ])
    for cid in organ_label_ids:
        for metric_name in ("dsc", "asd", "hd"):
            pre_vals = [
                v for v in metrics_pre[cid][metric_name].values() if not np.isnan(v)
            ]
            post_vals = [
                v for v in metrics[cid][metric_name].values() if not np.isnan(v)
            ]

            # [pre_mean, pre_std, post_mean, post_std]; NaN when nothing valid.
            stats = []
            for vals in (pre_vals, post_vals):
                if vals:
                    stats.extend([np.mean(vals), np.std(vals)])
                else:
                    stats.extend([float("nan"), float("nan")])

            n = max(len(pre_vals), len(post_vals))
            writer.writerow(
                [cid, metric_name.upper()] + [_fmt(value) for value in stats] + [n]
            )
print(f"Saved {overall_path}")
# --- Grand summary (mean across all classes) ---
# One row per metric: mean / std of the per-class means, before and after
# registration. Classes with no valid (non-NaN) value are left out of the
# averages; n_classes always reports the number of discovered label IDs.
grand_path = os.path.join(eval_dir, "grand_summary.csv")
with open(grand_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["metric", "pre_mean", "pre_std", "post_mean", "post_std", "n_classes"])
    for metric_name in ("dsc", "asd", "hd"):
        pre_class_means = []
        post_class_means = []
        for store, class_means in (
            (metrics_pre, pre_class_means),
            (metrics, post_class_means),
        ):
            for cid in organ_label_ids:
                valid = [
                    v for v in store[cid][metric_name].values() if not np.isnan(v)
                ]
                if valid:
                    class_means.append(np.mean(valid))

        row = [metric_name.upper()]
        for class_means in (pre_class_means, post_class_means):
            if class_means:
                row.append(_fmt(np.mean(class_means)))
                row.append(_fmt(np.std(class_means)))
            else:
                row.extend([_fmt(float("nan")), _fmt(float("nan"))])
        row.append(len(organ_label_ids))
        writer.writerow(row)
print(f"Saved {grand_path}")