"""
OM_reg_unpair_ext.py — Unpaired all-to-all registration using OMorpher with an
external Learn2Reg-style dataset JSON (e.g. OASIS).

Extracts unique subjects from the registration_val or registration_test pairs,
then registers every subject to every other subject.

Supports multi-class label maps (e.g. 35 brain regions) with auto-discovered
label IDs. Saves registered images, masks, DDFs, and evaluation metrics
(DSC, ASD, HD) per label class. Surface distances (ASD, HD) are measured in
voxel units of the label arrays being compared (no physical spacing is applied).

Usage:
    python Scripts/OM_reg_unpair_ext.py -C Config/config_reg_brain.yaml \
        --dataset-json ~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json \
        --split val

    python Scripts/OM_reg_unpair_ext.py -C Config/config_reg_brain.yaml \
        --dataset-json ~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json \
        --split test -N 10
"""

import argparse
import csv
import json
import os
import sys

# Add project root to path so imports work from Scripts/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
import yaml
import SimpleITK as sitk
from scipy.ndimage import distance_transform_edt, binary_erosion
from tqdm import tqdm

import utils
from Dataloader.dataLoader import reverse_axis_order
from OMorpher import OMorpher

# ========== CLI ==========
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_reg_brain.yaml",
    required=False,
)
parser.add_argument(
    "--dataset-json",
    help="Path to the Learn2Reg-style dataset JSON",
    type=str,
    default="~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/OASIS/OASIS_dataset.json",
)
parser.add_argument(
    "--split",
    help="Which registration split to use: 'val' or 'test'",
    type=str,
    choices=["val", "test"],
    default="val",
)
parser.add_argument(
    "--max-samples",
    "-N",
    help="Max number of subjects to include (0 = all)",
    type=int,
    default=0,
)
args = parser.parse_args()

# ========== Config ==========
with open(args.config, "r") as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Registration is run one pair at a time.
hyp_parameters["batchsize"] = 1
model_img_sz = hyp_parameters["img_size"]
timesteps = hyp_parameters["timesteps"]
condition_type = hyp_parameters["condition_type"]
ndims = hyp_parameters["ndims"]

# ========== Load external dataset JSON ==========
dataset_json_path = os.path.expanduser(args.dataset_json)
# Relative paths inside the JSON are resolved against the JSON's own directory.
dataset_root = os.path.dirname(dataset_json_path)

with open(dataset_json_path, "r") as f:
    dataset_meta = json.load(f)

dataset_name = dataset_meta.get("name", "UnknownDataset")
print(f"Dataset: {dataset_name}")

# Select registration split
if args.split == "val":
    pairs = dataset_meta.get("registration_val", [])
elif args.split == "test":
    pairs = dataset_meta.get("registration_test", [])
else:
    raise ValueError(f"Unknown split: {args.split}")
print(f"Split: {args.split}, Pairs in JSON: {len(pairs)}")

# Extract unique subject image paths from the pairs.
# A dict is used as an ordered set so subjects keep first-seen order.
_seen_paths = {}
for pair in pairs:
    for role in ("fixed", "moving"):
        rel = pair[role]
        if rel not in _seen_paths:
            _seen_paths[rel] = len(_seen_paths)
subject_rel_paths = list(_seen_paths.keys())

if args.max_samples > 0:
    subject_rel_paths = subject_rel_paths[: args.max_samples]
print(f"Unique subjects: {len(subject_rel_paths)} (max_samples={args.max_samples or 'all'})")

# Build label lookup: image basename -> label relative path
_label_lookup = {}
for entry in dataset_meta.get("training", []):
    img_base = os.path.basename(entry["image"])
    _label_lookup[img_base] = entry.get("label")
for entry in dataset_meta.get("test", []):
    img_base = os.path.basename(entry.get("image", ""))
    if entry.get("label"):
        _label_lookup[img_base] = entry["label"]

# ========== OMorpher setup ==========
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = os.path.join(
    f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
    str(epoch) + ".pth",
)
print("Loading model from:", model_save_path)

om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)

# ========== Output directories ==========
reg_img_savepath = hyp_parameters["reg_img_savepath"]
reg_msk_savepath = hyp_parameters["reg_msk_savepath"]
reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"]

reg_img_savepath_fullres = reg_img_savepath.rstrip("/") + "_fullres/"
reg_msk_savepath_fullres = reg_msk_savepath.rstrip("/") + "_fullres/"
reg_ddf_savepath_fullres = reg_ddf_savepath.rstrip("/") + "_fullres/"

eval_dir = os.path.join(reg_img_savepath, "..", "eval")

for p in [
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
]:
    os.makedirs(p, exist_ok=True)

# ========== Settings ==========
skip_self = True  # skip pairs where source == target


# ========== Helper functions ==========
def resolve_path(rel_path):
    """Resolve a relative path from the dataset JSON to an absolute path.

    Absolute paths are returned unchanged; relative ones are joined onto the
    directory containing the dataset JSON and normalised.
    """
    if os.path.isabs(rel_path):
        return rel_path
    return os.path.normpath(os.path.join(dataset_root, rel_path))


def load_volume(nifti_path):
    """Load a NIfTI volume: axis reorder only.

    A 4-D array is reduced to its first entry along the last axis.
    OMorpher._standardize_img handles: normalize -> pad-to-cube -> resize.
    """
    volume = sitk.ReadImage(nifti_path)
    volume = sitk.GetArrayFromImage(volume)
    volume = reverse_axis_order(volume)
    if volume.ndim == 4:
        volume = volume[:, :, :, 0]
    return volume


def load_label(nifti_path):
    """Load a NIfTI label map: axis reorder only.

    An array with more than three dimensions is reduced to its first entry
    along the fourth axis.
    OMorpher._standardize_label handles: pad-to-cube -> resize (nearest).
    """
    label = sitk.ReadImage(nifti_path)
    label = sitk.GetArrayFromImage(label)
    label = reverse_axis_order(label)
    if label.ndim > 3:
        label = label[:, :, :, 0]
    return label


def get_label_path_for_image(image_rel_path):
    """Find the label path for an image by looking up the training/test entries.

    Returns the absolute path of an existing label file, or None when no label
    can be located.
    """
    img_base = os.path.basename(image_rel_path)

    # Fix extension mismatch: JSON test entries may use .csv but files are .nii.gz
    for key in (img_base, img_base.replace(".nii.gz", ".csv")):
        label_rel = _label_lookup.get(key)
        if label_rel is None:
            continue
        # Ensure we use .nii.gz extension for the actual file
        label_abs = resolve_path(label_rel.replace(".csv", ".nii.gz"))
        if os.path.exists(label_abs):
            return label_abs

    # Fallback: derive label path from image path (images* -> labels*)
    img_abs = resolve_path(image_rel_path)
    label_abs = img_abs.replace("/images", "/labels")
    # If the path has no "/images" component the replacement is a no-op; without
    # this guard the image itself would be returned and loaded as a label map.
    if label_abs != img_abs and os.path.exists(label_abs):
        return label_abs
    return None


def get_volume_name(path):
    """Extract a short name from a NIfTI file path (basename without .nii/.nii.gz)."""
    name = os.path.basename(path)
    for ext in (".nii.gz", ".nii"):
        if name.endswith(ext):
            return name[: -len(ext)]
    return name


# ---------- Auto-discover label IDs ----------
def discover_label_ids(label_path):
    """Read a label NIfTI and return sorted non-zero unique IDs."""
    lab = sitk.ReadImage(label_path)
    lab = sitk.GetArrayFromImage(lab)
    unique = np.unique(lab).astype(int)
    return sorted([int(v) for v in unique if v > 0])


# ---------- Evaluation metrics ----------
def _surface_distances(pred, gt):
    """Compute directed surface distances between two binary masks.

    Masks are thresholded at 0.5. The surface of a mask is the mask XOR its
    binary erosion; if that is empty the whole mask is used instead.

    Returns:
        (pred_to_gt, gt_to_pred): 1-D arrays of Euclidean distances, in voxel
        units, from each surface voxel of one mask to the nearest surface voxel
        of the other. (None, None) if either mask is empty.
    """
    pred_bool = pred > 0.5
    gt_bool = gt > 0.5
    if not np.any(pred_bool) or not np.any(gt_bool):
        return None, None

    pred_surface = pred_bool ^ binary_erosion(pred_bool)
    gt_surface = gt_bool ^ binary_erosion(gt_bool)
    if not np.any(pred_surface):
        pred_surface = pred_bool
    if not np.any(gt_surface):
        gt_surface = gt_bool

    # EDT of the complement gives, at every voxel, the distance to the surface.
    dt_gt = distance_transform_edt(~gt_surface)
    dt_pred = distance_transform_edt(~pred_surface)
    return dt_gt[pred_surface], dt_pred[gt_surface]


def compute_dsc(pred, gt):
    """Dice Similarity Coefficient of two masks thresholded at 0.5.

    Returns 1.0 when both masks are empty.
    """
    pred_bool = pred > 0.5
    gt_bool = gt > 0.5
    intersection = np.sum(pred_bool & gt_bool)
    denom = np.sum(pred_bool) + np.sum(gt_bool)
    if denom == 0:
        return 1.0
    return 2.0 * float(intersection) / float(denom)


def compute_surface_metrics(pred, gt):
    """Return (ASD, HD) for two masks from a single surface-distance pass.

    ASD is the mean of the two directed mean surface distances; HD is the
    maximum of the two directed maximum distances. Both are NaN when either
    mask is empty. Computing them together avoids running the erosion and
    distance transforms twice for the same mask pair.
    """
    d1, d2 = _surface_distances(pred, gt)
    if d1 is None:
        return float("nan"), float("nan")
    asd = (np.mean(d1) + np.mean(d2)) / 2.0
    hd = float(max(np.max(d1), np.max(d2)))
    return asd, hd


def compute_asd(pred, gt):
    """Average (symmetric) Surface Distance; NaN if either mask is empty."""
    return compute_surface_metrics(pred, gt)[0]


def compute_hd(pred, gt):
    """Hausdorff Distance (maximum of directed HDs); NaN if either mask is empty."""
    return compute_surface_metrics(pred, gt)[1]


# ========== Pre-load all subjects ==========
subjects = []
organ_label_ids = None  # auto-discovered from first label file

for rel_path in tqdm(subject_rel_paths, desc="Loading subjects"):
    abs_path = resolve_path(rel_path)
    vol = load_volume(abs_path)

    # set_init_img populates both the model-resolution and raw tensors read below.
    om.set_init_img(vol)
    img_model = om._init_img.clone()
    img_fullres = om._init_img_raw.clone()
    orig_sz = list(img_fullres.shape[2:])

    # Load label (single-channel multi-class map)
    label_path = get_label_path_for_image(rel_path)
    label_model, label_fullres = None, None
    if label_path is not None and os.path.exists(label_path):
        lab = load_label(label_path)
        label_model, label_fullres = om._standardize_label(lab)

        # Auto-discover label IDs from the first available label
        if organ_label_ids is None:
            organ_label_ids = discover_label_ids(label_path)
            print(f"Auto-discovered {len(organ_label_ids)} label classes: {organ_label_ids}")

    subjects.append({
        "rel_path": rel_path,
        "img_model": img_model,
        "img_fullres": img_fullres,
        "label_model": label_model,
        "label_fullres": label_fullres,
        "orig_sz": orig_sz,
    })

print(f"Loaded {len(subjects)} subjects into memory.")
if organ_label_ids is None:
    organ_label_ids = []
    print("No labels found — skipping evaluation metrics.")
else:
    print(f"Organ labels for evaluation: {organ_label_ids}")

# ========== Prepare evaluation structures ==========
vol_names = [get_volume_name(subj["rel_path"]) for subj in subjects]

# Disambiguate duplicate basenames by appending index
_seen = {}
for i, vn in enumerate(vol_names):
    _seen.setdefault(vn, []).append(i)
for vn, indices in _seen.items():
    if len(indices) > 1:
        for idx in indices:
            vol_names[idx] = f"{vn}_{idx}"

# metrics[label_id][metric_name][(t, s)] = value (post-registration)
metrics = {
    cid: {"dsc": {}, "asd": {}, "hd": {}} for cid in organ_label_ids
}
# metrics_pre: same structure for pre-registration
metrics_pre = {
    cid: {"dsc": {}, "asd": {}, "hd": {}} for cid in organ_label_ids
}

# ========== All-to-all registration ==========
with torch.no_grad():
    for t, tgt in enumerate(tqdm(subjects, desc="Targets")):
        tgt_tag = f"Tgt{t:04d}"

        # --- Save target original at model resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_model"], ndims=ndims),
            os.path.join(reg_img_savepath, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["label_model"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["label_model"], ndims=ndims),
                os.path.join(reg_msk_savepath, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Save target original at full resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_fullres"], ndims=ndims),
            os.path.join(reg_img_savepath_fullres, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["label_fullres"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["label_fullres"], ndims=ndims),
                os.path.join(reg_msk_savepath_fullres, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Inner loop: register each source to this target ---
        for s, src in enumerate(subjects):
            if skip_self and s == t:
                continue

            pair_tag = f"Tgt{t:04d}_Src{s:04d}"
            print(f" Registering {pair_tag}")

            om.set_init_img(src["img_model"])
            om.set_cond_img(tgt["img_model"].clone().detach())
            om.predict(
                T=[None, timesteps],
                proc_type=condition_type,
            )
            ddf_comp = om.get_def()

            # --- Model-resolution registered image ---
            # NOTE(review): this uses padding_mode="zeros" while the full-res
            # image below uses "border" — confirm the difference is intended.
            img_rec = om.apply_def(
                img=src["img_model"],
                ddf=ddf_comp,
                padding_mode="zeros",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec, ndims=ndims),
                os.path.join(reg_img_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Model-resolution registered label ---
            label_rec = None
            if src["label_model"] is not None:
                label_rec = om.apply_def(
                    img=src["label_model"],
                    ddf=ddf_comp,
                    padding_mode="zeros",
                    resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(label_rec, ndims=ndims),
                    os.path.join(reg_msk_savepath, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Model-resolution DDF ---
            nib.save(
                utils.converet_to_nibabel(ddf_comp, ndims=ndims),
                os.path.join(reg_ddf_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered image ---
            img_rec_fullres = om.apply_def(
                img=src["img_fullres"],
                ddf=ddf_comp,
                padding_mode="border",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec_fullres, ndims=ndims),
                os.path.join(reg_img_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered label ---
            label_rec_fullres = None
            if src["label_fullres"] is not None:
                label_rec_fullres = om.apply_def(
                    img=src["label_fullres"],
                    ddf=ddf_comp,
                    padding_mode="zeros",
                    resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(label_rec_fullres, ndims=ndims),
                    os.path.join(reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Full-resolution DDF ---
            # NOTE(review): only the sampling grid is upsampled here; the
            # displacement values themselves are not rescaled — confirm this
            # matches the DDF convention expected by downstream consumers.
            ddf_fullres = F.interpolate(
                ddf_comp,
                size=src["orig_sz"],
                mode="trilinear",
                align_corners=False,
            )
            nib.save(
                utils.converet_to_nibabel(ddf_fullres, ndims=ndims),
                os.path.join(reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Evaluation metrics (per-class from multi-class label) ---
            if (
                organ_label_ids
                and label_rec_fullres is not None
                and tgt["label_fullres"] is not None
            ):
                tgt_label_np = tgt["label_fullres"][0, 0].cpu().numpy()
                src_label_np = src["label_fullres"][0, 0].cpu().numpy()
                reg_label_np = label_rec_fullres[0, 0].cpu().numpy()

                for cid in organ_label_ids:
                    tgt_mask = (np.round(tgt_label_np) == cid).astype(np.float32)
                    src_mask = (np.round(src_label_np) == cid).astype(np.float32)
                    reg_mask = (np.round(reg_label_np) == cid).astype(np.float32)

                    # Skip if both masks are empty
                    if np.sum(tgt_mask) == 0 and np.sum(src_mask) == 0:
                        continue

                    # Pre-registration: source vs target
                    pre_dsc = compute_dsc(src_mask, tgt_mask)
                    pre_asd, pre_hd = compute_surface_metrics(src_mask, tgt_mask)
                    metrics_pre[cid]["dsc"][(t, s)] = pre_dsc
                    metrics_pre[cid]["asd"][(t, s)] = pre_asd
                    metrics_pre[cid]["hd"][(t, s)] = pre_hd

                    # Post-registration: registered label vs target
                    post_dsc = compute_dsc(reg_mask, tgt_mask)
                    post_asd, post_hd = compute_surface_metrics(reg_mask, tgt_mask)
                    metrics[cid]["dsc"][(t, s)] = post_dsc
                    metrics[cid]["asd"][(t, s)] = post_asd
                    metrics[cid]["hd"][(t, s)] = post_hd

                # Print summary for this pair (mean across classes)
                post_dscs = [
                    metrics[cid]["dsc"][(t, s)]
                    for cid in organ_label_ids
                    if (t, s) in metrics[cid]["dsc"]
                ]
                if post_dscs:
                    pre_dscs = [
                        metrics_pre[cid]["dsc"][(t, s)]
                        for cid in organ_label_ids
                        if (t, s) in metrics_pre[cid]["dsc"]
                    ]
                    print(
                        f" Mean DSC: pre={np.mean(pre_dscs):.4f} "
                        f"post={np.mean(post_dscs):.4f}"
                    )

print("\nAll-to-all unpaired registration complete.")

# ========== Write evaluation CSVs ==========
n_subj = len(subjects)
METRIC_NAMES = ["dsc", "asd", "hd"]


def _fmt(val):
    """Format a metric for CSV: '' for missing, 'NaN' for NaN, else 6 decimals."""
    if val is None:
        return ""
    if np.isnan(val):
        return "NaN"
    return f"{val:.6f}"


def _finite(values):
    """Return the non-NaN entries of an iterable of floats as a list."""
    return [v for v in values if not np.isnan(v)]


def _mean_std(vals):
    """Return (mean, std) of a list, or (NaN, NaN) when the list is empty."""
    if not vals:
        return float("nan"), float("nan")
    return np.mean(vals), np.std(vals)


def _write_metric_matrix(csv_path, values):
    """Write one target-by-source matrix CSV from a {(t, s): value} mapping.

    Rows are targets and columns are sources, both labelled with vol_names;
    pairs absent from the mapping are written as empty cells.
    """
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["target \\ source"] + vol_names)
        for t_idx in range(n_subj):
            row = [vol_names[t_idx]]
            row.extend(_fmt(values.get((t_idx, s_idx))) for s_idx in range(n_subj))
            writer.writerow(row)
    print(f"Saved {csv_path}")


# --- Per-class matrix CSVs: post-registration first, then pre-registration ---
for infix, store in (("", metrics), ("pre_", metrics_pre)):
    for cid in organ_label_ids:
        prefix = f"label{cid:02d}_{infix}"
        for metric_name in METRIC_NAMES:
            _write_metric_matrix(
                os.path.join(eval_dir, f"{prefix}{metric_name}.csv"),
                store[cid][metric_name],
            )

# --- Overall summary ---
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([
        "label_id", "metric", "pre_mean", "pre_std",
        "post_mean", "post_std", "n_pairs",
    ])
    for cid in organ_label_ids:
        for metric_name in METRIC_NAMES:
            pre_vals = _finite(metrics_pre[cid][metric_name].values())
            post_vals = _finite(metrics[cid][metric_name].values())
            pre_mean, pre_std = _mean_std(pre_vals)
            post_mean, post_std = _mean_std(post_vals)
            n = max(len(pre_vals), len(post_vals))
            writer.writerow([
                cid, metric_name.upper(),
                _fmt(pre_mean), _fmt(pre_std),
                _fmt(post_mean), _fmt(post_std),
                n,
            ])
print(f"Saved {overall_path}")

# --- Grand summary (mean across all classes) ---
grand_path = os.path.join(eval_dir, "grand_summary.csv")
with open(grand_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["metric", "pre_mean", "pre_std", "post_mean", "post_std", "n_classes"])
    for metric_name in METRIC_NAMES:
        # Average the per-class means so every class carries equal weight.
        pre_class_means = []
        post_class_means = []
        for cid in organ_label_ids:
            pre_vals = _finite(metrics_pre[cid][metric_name].values())
            post_vals = _finite(metrics[cid][metric_name].values())
            if pre_vals:
                pre_class_means.append(np.mean(pre_vals))
            if post_vals:
                post_class_means.append(np.mean(post_vals))
        pre_mean, pre_std = _mean_std(pre_class_means)
        post_mean, post_std = _mean_std(post_class_means)
        writer.writerow([
            metric_name.upper(),
            _fmt(pre_mean),
            _fmt(pre_std),
            _fmt(post_mean),
            _fmt(post_std),
            len(organ_label_ids),
        ])
print(f"Saved {grand_path}")