""" OM_reg_unpair.py — Unpaired all-to-all registration using OMorpher. Registers every subject to every other subject in a nested loop. Output naming uses Tgt/Src prefixes instead of the Patient/Slice barcode. Computes DSC, ASD, HD for organ labels (excludes tumour/lesion labels) and saves per-pair tables and summary statistics as CSVs. Usage: python Scripts/OM_reg_unpair.py -C Config/config_om.yaml """ import os import sys # Add project root to path so imports work from Scripts/ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import csv import numpy as np import torch import torch.nn.functional as F import nibabel as nib import yaml import SimpleITK as sitk from scipy.ndimage import distance_transform_edt, binary_erosion from tqdm import tqdm import utils from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order from OMorpher import OMorpher # ========== CLI ========== import argparse parser = argparse.ArgumentParser() parser.add_argument( "--config", "-C", help="Path for the config file", type=str, default="Config/config_om.yaml", required=False, ) parser.add_argument( "--max-samples", "-N", help="Max number of subjects to include (0 = all)", type=int, default=0, ) args = parser.parse_args() # ========== Config ========== with open(args.config, "r") as file: hyp_parameters = yaml.safe_load(file) print(hyp_parameters) hyp_parameters["batchsize"] = 1 model_img_sz = hyp_parameters["img_size"] timesteps = hyp_parameters["timesteps"] condition_type = hyp_parameters["condition_type"] ndims = hyp_parameters["ndims"] # ========== Dataset ========== label_keys = ["brain"] database = ["Brats2019"] dataset = OminiDataset_inference_w_all( transform=None, min_crop_ratio=1.0, label_key=label_keys, database=database, ) # ========== OMorpher setup ========== epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}' model_save_path = os.path.join( 
f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/', str(epoch) + ".pth", ) print("Loading model from:", model_save_path) om = OMorpher( config=hyp_parameters, checkpoint_path=model_save_path, device=str(hyp_parameters.get("device", "cpu")), ) print(om) # ========== Output directories ========== reg_img_savepath = hyp_parameters["reg_img_savepath"] reg_msk_savepath = hyp_parameters["reg_msk_savepath"] reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"] reg_img_savepath_fullres = reg_img_savepath.rstrip("/") + "_fullres/" reg_msk_savepath_fullres = reg_msk_savepath.rstrip("/") + "_fullres/" reg_ddf_savepath_fullres = reg_ddf_savepath.rstrip("/") + "_fullres/" eval_dir = os.path.join(reg_img_savepath, "..", "eval") for p in [ reg_img_savepath, reg_msk_savepath, reg_ddf_savepath, reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres, eval_dir, ]: os.makedirs(p, exist_ok=True) # ========== Settings ========== skip_self = True # skip pairs where source == target # Labels that are NOT organs — excluded from metric evaluation. # BraTS "brain" is actually a tumour segmentation (non-enhancing tumour, edema, # enhancing tumour), not a whole-brain mask, so it must be excluded. 
EXCLUDE_LABELS = {
    "brain",  # BraTS tumour segmentation
    "tumor",  # PSMA-CT / PSMA-FDG tumour
    "tumour",
    "noisy",  # Kaggle OSIC artefact mask
}
# Any label containing these substrings is also excluded
EXCLUDE_SUBSTRINGS = {"lesion", "tumor", "tumour"}


# ========== Helper functions ==========

def center_pad_to_cube(volume):
    """Zero-pad the leading (up to three) axes of *volume* to a common length.

    The target length is the largest of the leading axes; the padding is
    split symmetrically, with the extra voxel (odd totals) placed after the
    data.  Any axes beyond the third are left untouched.

    Args:
        volume: numpy array.

    Returns:
        The padded numpy array.
    """
    max_dim = max(volume.shape[:3])
    pad_width = []
    for s in volume.shape[:3]:
        total_pad = max_dim - s
        pad_before = total_pad // 2
        pad_width.append((pad_before, total_pad - pad_before))
    # Trailing (e.g. channel) axes are never padded.
    pad_width.extend((0, 0) for _ in range(volume.ndim - 3))
    return np.pad(volume, pad_width, mode="constant", constant_values=0)


def load_fullres_volume(key, ds):
    """Load an original-resolution volume and prepare it for the model.

    Steps: read with SimpleITK, reorder axes with ``reverse_axis_order``,
    pick one channel of 4-D data, clip CT intensities to ``ds.clamp_range``
    (when set), apply ``ds.normalize`` and center-pad to a cube.

    Args:
        key: image file path; also the key into ``ds.ALLdata_filtered``.
        ds: dataset object providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.

    Returns:
        The prepared numpy volume.
    """
    volume = sitk.GetArrayFromImage(sitk.ReadImage(key))
    volume = reverse_axis_order(volume)
    if volume.ndim == 4:
        # Multi-channel image: keep the first configured channel (0 if none).
        channel_ids = ds.get_channel_ids(key)
        channel_id = channel_ids[0] if len(channel_ids) > 0 else 0
        volume = volume[:, :, :, channel_id]
    if ds.clamp_range is not None:
        # Intensity clamping is only applied to CT data.
        modality = ds.ALLdata_filtered[key].get("Modality", None)
        if modality == "CT":
            volume = np.clip(volume, ds.clamp_range[0], ds.clamp_range[1])
    volume = ds.normalize(volume)
    return center_pad_to_cube(volume)


def load_fullres_label(key, ds, label_key):
    """Load an original-resolution segmentation label for one subject.

    Args:
        key: key into ``ds.ALLdata_filtered`` identifying the subject.
        ds: dataset object providing ``ALLdata_filtered`` and
            ``get_channel_ids``.
        label_key: name of the segmentation label to load.

    Returns:
        The axis-reordered, center-padded label array, or ``None`` when the
        subject has no segmentation entry for *label_key*.
    """
    label_path_dict = ds.ALLdata_filtered[key].get("Label_path", {})
    task_labels = label_path_dict.get("segmentation", {})
    if label_key not in task_labels:
        return None
    label = sitk.GetArrayFromImage(sitk.ReadImage(task_labels[label_key]))
    label = reverse_axis_order(label)
    if label.ndim > 3:
        channel_ids = ds.get_channel_ids(key)
        if len(channel_ids) != 0:
            label = label[..., channel_ids]
    return center_pad_to_cube(label)


def get_volume_name(key):
    """Return the basename of *key* with a trailing .nii.gz / .nii removed."""
    name = os.path.basename(key)
    for ext in (".nii.gz", ".nii"):
        if name.endswith(ext):
            return name[: -len(ext)]
    return name


def is_organ_label(label_key):
    """Return True if *label_key* is an organ label (not tumour/lesion).

    A label is rejected when its lower-cased name is in ``EXCLUDE_LABELS`` or
    contains any of ``EXCLUDE_SUBSTRINGS``.
    """
    lk_lower = label_key.lower()
    if lk_lower in EXCLUDE_LABELS:
        return False
    return not any(kw in lk_lower for kw in EXCLUDE_SUBSTRINGS)


def _write_pair_matrix_csv(csv_path, names, pair_values):
    """Write a target-by-source matrix of per-pair values to *csv_path*.

    Args:
        csv_path: output file path.
        names: list of volume names, used for both the header and row labels.
        pair_values: dict mapping ``(target_idx, source_idx)`` to a float.
            Missing pairs (e.g. self-pairs) give an empty cell, NaN values
            are written as ``"NaN"``.
    """
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        # Header row: corner cell + source volume names
        writer.writerow(["target \\ source"] + names)
        for t_idx, t_name in enumerate(names):
            row = [t_name]
            for s_idx in range(len(names)):
                val = pair_values.get((t_idx, s_idx))
                if val is None:
                    row.append("")  # self-pair or missing
                elif np.isnan(val):
                    row.append("NaN")
                else:
                    row.append(f"{val:.6f}")
            writer.writerow(row)
    print(f"Saved {csv_path}")


# ---------- Evaluation metrics ----------

def _surface_distances(pred, gt):
    """Compute directed surface distances between two binary masks.

    Masks are binarised at 0.5.  Surfaces are the voxels removed by one
    binary erosion; distances are Euclidean and in voxel units (no spacing
    is passed to the distance transform).

    Returns:
        ``(dist_pred_to_gt, dist_gt_to_pred)`` arrays, or ``(None, None)``
        if either mask is empty.
    """
    pred_bool = pred > 0.5
    gt_bool = gt > 0.5
    if not np.any(pred_bool) or not np.any(gt_bool):
        return None, None

    # Extract surface voxels via erosion
    struct = None  # default 3x3(x3) cross connectivity
    pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
    gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)

    # Fallback: if no surface voxel was extracted, use the whole mask
    if not np.any(pred_surface):
        pred_surface = pred_bool
    if not np.any(gt_surface):
        gt_surface = gt_bool

    # Distance of every voxel to the nearest surface voxel of the other mask
    dt_gt = distance_transform_edt(~gt_surface)
    dt_pred = distance_transform_edt(~pred_surface)
    return dt_gt[pred_surface], dt_pred[gt_surface]


def _asd_and_hd(pred, gt):
    """Return ``(ASD, HD)`` from a single surface-distance computation.

    Sharing one call to ``_surface_distances`` avoids running the erosion and
    the two distance transforms twice when both metrics are needed.
    Both values are NaN when either mask is empty.
    """
    d_pred_to_gt, d_gt_to_pred = _surface_distances(pred, gt)
    if d_pred_to_gt is None:
        return float("nan"), float("nan")
    asd = (np.mean(d_pred_to_gt) + np.mean(d_gt_to_pred)) / 2.0
    hd = float(max(np.max(d_pred_to_gt), np.max(d_gt_to_pred)))
    return asd, hd


def compute_dsc(pred, gt):
    """Dice Similarity Coefficient of two masks binarised at 0.5.

    Returns 1.0 when both masks are empty (perfect agreement).
    """
    pred_bool = pred > 0.5
    gt_bool = gt > 0.5
    intersection = np.sum(pred_bool & gt_bool)
    denom = np.sum(pred_bool) + np.sum(gt_bool)
    if denom == 0:
        return 1.0  # both empty — perfect agreement
    return 2.0 * float(intersection) / float(denom)


def compute_asd(pred, gt):
    """Average (symmetric) Surface Distance; NaN if either mask is empty."""
    return _asd_and_hd(pred, gt)[0]


def compute_hd(pred, gt):
    """Hausdorff Distance (maximum of directed HDs); NaN if a mask is empty."""
    return _asd_and_hd(pred, gt)[1]


def _forward_diff(component, axis):
    """Forward difference of *component* along *axis*, same shape as input.

    The last slice is appended to itself, so the difference at the far
    boundary is zero.
    """
    edge = [slice(None)] * component.ndim
    edge[axis] = slice(-1, None)
    return np.diff(component, axis=axis, append=component[tuple(edge)])


def compute_negdetj_pct(ddf, ndims=3):
    """Percent of voxels with negative Jacobian determinant.

    The Jacobian of the transform ``x + u(x)`` is built from forward
    differences of the displacement components (identity added on the
    diagonal).

    Args:
        ddf: displacement field tensor [1, ndims, ...] or numpy array.
        ndims: 2 or 3.

    Returns:
        Percentage of voxels where det(Jacobian) < 0.

    Raises:
        ValueError: if *ndims* is neither 2 nor 3.
    """
    if isinstance(ddf, torch.Tensor):
        ddf = ddf.detach().cpu().numpy()
    if ddf.ndim == ndims + 2:
        ddf = ddf[0]  # remove batch dim -> [C, ...]

    if ndims == 3:
        dux_dx, dux_dy, dux_dz = (_forward_diff(ddf[0], ax) for ax in range(3))
        duy_dx, duy_dy, duy_dz = (_forward_diff(ddf[1], ax) for ax in range(3))
        duz_dx, duz_dy, duz_dz = (_forward_diff(ddf[2], ax) for ax in range(3))

        j11 = 1.0 + dux_dx
        j12 = dux_dy
        j13 = dux_dz
        j21 = duy_dx
        j22 = 1.0 + duy_dy
        j23 = duy_dz
        j31 = duz_dx
        j32 = duz_dy
        j33 = 1.0 + duz_dz

        # Cofactor expansion along the first row
        detj = (
            j11 * (j22 * j33 - j23 * j32)
            - j12 * (j21 * j33 - j23 * j31)
            + j13 * (j21 * j32 - j22 * j31)
        )
    elif ndims == 2:
        dux_dx, dux_dy = (_forward_diff(ddf[0], ax) for ax in range(2))
        duy_dx, duy_dy = (_forward_diff(ddf[1], ax) for ax in range(2))
        detj = (1.0 + dux_dx) * (1.0 + duy_dy) - dux_dy * duy_dx
    else:
        raise ValueError(f"Unsupported ndims={ndims}")

    n_neg = np.sum(detj < 0)
    n_total = detj.size
    return 100.0 * float(n_neg) / float(n_total)


# ========== Pre-load all subjects ==========
keys = list(dataset.ALLdata_filtered.keys())
if args.max_samples > 0:
    keys = keys[: args.max_samples]
print(f"Total subjects: {len(keys)} (max_samples={args.max_samples or 'all'})")

subjects = []
for key in tqdm(keys, desc="Loading subjects"):
    fullres_vol = load_fullres_volume(key, dataset)
    # set_init_img populates om._init_img / om._init_img_raw, which are
    # copied here as the model-resolution and full-resolution tensors.
    om.set_init_img(fullres_vol)
    img_model = om._init_img.clone()
    img_fullres = om._init_img_raw.clone()
    orig_sz = list(img_fullres.shape[2:])

    masks_model = []
    masks_fullres = []
    for lk in label_keys:
        lab = load_fullres_label(key, dataset, lk)
        model_t, fullres_t = om._standardize_label(lab)
        masks_model.append(model_t)
        masks_fullres.append(fullres_t)

    if masks_model:
        # One channel per label key, in label_keys order
        mask_model = torch.cat(masks_model, dim=1)
        mask_fullres = torch.cat(masks_fullres, dim=1)
    else:
        mask_model = None
        mask_fullres = None

    subjects.append({
        "key": key,
        "img_model": img_model,
        "img_fullres": img_fullres,
        "mask_model": mask_model,
        "mask_fullres": mask_fullres,
        "orig_sz": orig_sz,
    })

print(f"Loaded {len(subjects)} subjects into memory.")

# ========== Prepare evaluation structures ==========
vol_names = [get_volume_name(subj["key"]) for subj in subjects]

# Disambiguate duplicate basenames by appending index
_seen = {}
for i, vn in enumerate(vol_names):
    _seen.setdefault(vn, []).append(i)
for vn, indices in _seen.items():
    if len(indices) > 1:
        for idx in indices:
            vol_names[idx] = f"{vn}_{idx}"

# (channel_index, label_key) for organ labels only
organ_label_indices = [
    (c, lk) for c, lk in enumerate(label_keys) if is_organ_label(lk)
]

if organ_label_indices:
    print(f"Organ labels for evaluation: {[lk for _, lk in organ_label_indices]}")
else:
    print("No organ labels found — skipping evaluation metrics.")

# metrics[label_key][metric_name][(t, s)] = value
metrics = {
    lk: {"dsc": {}, "asd": {}, "hd": {}}
    for _, lk in organ_label_indices
}

# Per-pair DDF quality metric
negdetj_pct = {}  # (t, s) -> percentage of negative Jacobian determinant

# F.interpolate needs a mode that matches the spatial rank of the DDF:
# "bilinear" for 2-D fields (4-D tensors), "trilinear" for 3-D (5-D tensors).
ddf_interp_mode = "bilinear" if ndims == 2 else "trilinear"

# ========== All-to-all registration ==========
with torch.no_grad():
    for t, tgt in enumerate(tqdm(subjects, desc="Targets")):
        tgt_tag = f"Tgt{t:04d}"

        # --- Save target original at model resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_model"], ndims=ndims),
            os.path.join(reg_img_savepath, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["mask_model"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["mask_model"], ndims=ndims),
                os.path.join(reg_msk_savepath, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Save target original at full resolution ---
        nib.save(
            utils.converet_to_nibabel(tgt["img_fullres"], ndims=ndims),
            os.path.join(reg_img_savepath_fullres, f"{tgt_tag}_ORG.nii.gz"),
        )
        if tgt["mask_fullres"] is not None:
            nib.save(
                utils.converet_to_nibabel(tgt["mask_fullres"], ndims=ndims),
                os.path.join(reg_msk_savepath_fullres, f"{tgt_tag}_ORG_GT.nii.gz"),
            )

        # --- Inner loop: register each source to this target ---
        for s, src in enumerate(subjects):
            if skip_self and s == t:
                continue

            pair_tag = f"Tgt{t:04d}_Src{s:04d}"
            print(f" Registering {pair_tag}")

            om.set_init_img(src["img_model"])
            om.set_cond_img(tgt["img_model"].clone().detach())
            om.predict(
                T=[None, timesteps],
                proc_type=condition_type,
            )
            ddf_comp = om.get_def()

            # --- DDF quality: percent negative Jacobian determinant ---
            neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
            negdetj_pct[(t, s)] = neg_pct
            print(f" %|J|<0 = {neg_pct:.4f}%")

            # --- Model-resolution registered image ---
            img_rec = om.apply_def(
                img=src["img_model"],
                ddf=ddf_comp,
                padding_mode="zeros",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec, ndims=ndims),
                os.path.join(reg_img_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Model-resolution registered mask ---
            msk_rec = None
            if src["mask_model"] is not None:
                msk_rec = om.apply_def(
                    img=src["mask_model"],
                    ddf=ddf_comp,
                    padding_mode="zeros",
                    resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(msk_rec, ndims=ndims),
                    os.path.join(reg_msk_savepath, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Model-resolution DDF ---
            nib.save(
                utils.converet_to_nibabel(ddf_comp, ndims=ndims),
                os.path.join(reg_ddf_savepath, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered image ---
            img_rec_fullres = om.apply_def(
                img=src["img_fullres"],
                ddf=ddf_comp,
                padding_mode="border",
            )
            nib.save(
                utils.converet_to_nibabel(img_rec_fullres, ndims=ndims),
                os.path.join(reg_img_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Full-resolution registered mask ---
            msk_rec_fullres = None
            if src["mask_fullres"] is not None:
                msk_rec_fullres = om.apply_def(
                    img=src["mask_fullres"],
                    ddf=ddf_comp,
                    padding_mode="zeros",
                    resample_mode="nearest",
                )
                nib.save(
                    utils.converet_to_nibabel(msk_rec_fullres, ndims=ndims),
                    os.path.join(reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz"),
                )

            # --- Full-resolution DDF ---
            # NOTE(review): only the grid is resampled; the displacement
            # values are not rescaled — confirm this matches the DDF units
            # OMorpher uses.
            ddf_fullres = F.interpolate(
                ddf_comp,
                size=src["orig_sz"],
                mode=ddf_interp_mode,
                align_corners=False,
            )
            nib.save(
                utils.converet_to_nibabel(ddf_fullres, ndims=ndims),
                os.path.join(reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz"),
            )

            # --- Evaluation metrics (full-res organ labels) ---
            if (
                organ_label_indices
                and msk_rec_fullres is not None
                and tgt["mask_fullres"] is not None
            ):
                for c, lk in organ_label_indices:
                    tgt_mask_np = tgt["mask_fullres"][0, c].cpu().numpy()
                    reg_mask_np = msk_rec_fullres[0, c].cpu().numpy()

                    # Skip placeholder masks (fill_value = -1)
                    if np.all(tgt_mask_np < 0) or np.all(reg_mask_np < 0):
                        continue

                    # Voxel-wise metrics need identical grids; skip the pair
                    # instead of aborting the whole run on a broadcast error.
                    if tgt_mask_np.shape != reg_mask_np.shape:
                        print(
                            f" [{lk}] shape mismatch: target "
                            f"{tgt_mask_np.shape} vs registered "
                            f"{reg_mask_np.shape} - metrics skipped"
                        )
                        continue

                    dsc_val = compute_dsc(reg_mask_np, tgt_mask_np)
                    # One surface-distance pass serves both ASD and HD
                    asd_val, hd_val = _asd_and_hd(reg_mask_np, tgt_mask_np)

                    metrics[lk]["dsc"][(t, s)] = dsc_val
                    metrics[lk]["asd"][(t, s)] = asd_val
                    metrics[lk]["hd"][(t, s)] = hd_val

                    print(
                        f" [{lk}] DSC={dsc_val:.4f} "
                        f"ASD={asd_val:.2f} HD={hd_val:.2f}"
                    )

print("All-to-all registration complete.")

# ========== Write evaluation CSVs ==========
n_subj = len(subjects)

# --- %|J|<0 matrix CSV ---
negdetj_csv_path = os.path.join(eval_dir, "negdetj_pct.csv")
_write_pair_matrix_csv(negdetj_csv_path, vol_names, negdetj_pct)

for _, lk in organ_label_indices:
    # Use label prefix only when there are multiple organ labels
    prefix = f"{lk}_" if len(organ_label_indices) > 1 else ""

    for metric_name in ["dsc", "asd", "hd"]:
        csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
        _write_pair_matrix_csv(csv_path, vol_names, metrics[lk][metric_name])

# --- Overall summary ---
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["label", "metric", "mean", "std", "n_pairs"])

    # %|J|<0 summary (not per-label)
    negdetj_vals = [v for v in negdetj_pct.values() if not np.isnan(v)]
    if negdetj_vals:
        writer.writerow([
            "ALL",
            "%|J|<0",
            f"{np.mean(negdetj_vals):.6f}",
            f"{np.std(negdetj_vals):.6f}",
            len(negdetj_vals),
        ])

    for _, lk in organ_label_indices:
        for metric_name in ["dsc", "asd", "hd"]:
            # NaN entries (empty masks) are excluded from the statistics
            vals = [
                v for v in metrics[lk][metric_name].values()
                if not np.isnan(v)
            ]
            if vals:
                mean_val = np.mean(vals)
                std_val = np.std(vals)
                n_pairs = len(vals)
            else:
                mean_val = float("nan")
                std_val = float("nan")
                n_pairs = 0
            writer.writerow([
                lk,
                metric_name.upper(),
                f"{mean_val:.6f}" if not np.isnan(mean_val) else "NaN",
                f"{std_val:.6f}" if not np.isnan(std_val) else "NaN",
                n_pairs,
            ])
print(f"Saved {overall_path}")