r"""Evaluation script for baseline models.

Uses the exact same metrics as cimd/evaluate.py:
    CI Score, Combined Sensitivity, D_max, Diversity Agreement, GED,
    Avg Dice, Avg IoU, Avg Hausdorff

Usage:
    python evaluate.py --samples_dir ./results/prob_unet/samples_n16/samples \
        --gt_data_dir /workspace/multiannotator_dataset \
        --results_file ./results/prob_unet/evaluation_results.csv
"""

import argparse
import glob
import os
import sys

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from scipy.spatial.distance import directed_hausdorff

# Add baselines root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from dataset import MultiAnnotatorDataset


# ---------------------------------------------------------------------------
# Metrics (identical to cimd/evaluate.py)
# ---------------------------------------------------------------------------

def dice_coefficient(pred, target):
    """Return the smoothed Dice coefficient between two masks.

    Both arrays are cast to bool first. When neither mask has any foreground
    the masks are treated as a perfect match and 1.0 is returned.
    """
    smooth = 1e-6
    pred, target = pred.astype(bool), target.astype(bool)
    if not np.any(pred) and not np.any(target):
        return 1.0
    intersection = np.sum(pred & target)
    return (2. * intersection + smooth) / (
        np.sum(pred) + np.sum(target) + smooth)


def iou_score(pred, target):
    """Return the smoothed intersection-over-union between two masks.

    Both arrays are cast to bool first. Two masks without any foreground
    score 1.0.
    """
    smooth = 1e-6
    pred, target = pred.astype(bool), target.astype(bool)
    if not np.any(pred) and not np.any(target):
        return 1.0
    intersection = np.sum(pred & target)
    union = np.sum(pred | target)
    return (intersection + smooth) / (union + smooth)


def hausdorff_distance(pred, target):
    """Return the symmetric Hausdorff distance between two masks.

    The distance is measured between the coordinates of the non-zero
    elements of each mask. ``np.nan`` is returned when either mask has no
    non-zero element, because the distance is undefined in that case.
    """
    pred_points = np.argwhere(pred)
    target_points = np.argwhere(target)
    if len(pred_points) == 0 or len(target_points) == 0:
        return np.nan
    # Symmetric distance = the larger of the two directed distances.
    return max(directed_hausdorff(pred_points, target_points)[0],
               directed_hausdorff(target_points, pred_points)[0])


def calculate_combined_sensitivity(samples, gts):
    """Return the fraction of the union of ground truths covered by samples.

    All samples are OR-ed into one mask and all ground truths into another;
    the result is ``|samples ∩ gts| / |gts|``. Returns 1.0 when the combined
    ground truth has no foreground.
    """
    combined_sample = np.logical_or.reduce([s.astype(bool) for s in samples])
    combined_gt = np.logical_or.reduce([g.astype(bool) for g in gts])
    if not np.any(combined_gt):
        return 1.0
    tp = np.sum(combined_sample & combined_gt)
    return tp / np.sum(combined_gt)


def calculate_d_max(samples, gts):
    """Return the mean, over ground truths, of the best Dice among samples.

    Returns 0.0 when ``gts`` is empty. A ground truth contributes 0.0 when
    ``samples`` is empty.
    """
    if not gts:
        return 0.0
    max_dice_scores = []
    for gt in gts:
        dices = [dice_coefficient(s, gt) for s in samples] if samples else [0.0]
        max_dice_scores.append(np.max(dices))
    return np.mean(max_dice_scores)


def calculate_diversity_agreement(samples, gts):
    """Return how closely the spread of samples matches that of ground truths.

    For each set of masks the minimum and maximum pairwise ``1 - Dice`` are
    computed (both 0 when the set has fewer than two masks). The result is
    one minus the mean absolute difference of those two extremes between the
    ground-truth set and the sample set.
    """
    def get_variances(masks):
        """Return (min, max) of pairwise ``1 - Dice`` over ``masks``."""
        if len(masks) < 2:
            return 0, 0
        scores = [1.0 - dice_coefficient(masks[i], masks[j])
                  for i in range(len(masks))
                  for j in range(i + 1, len(masks))]
        return (np.min(scores) if scores else 0,
                np.max(scores) if scores else 0)

    V_min_gt, V_max_gt = get_variances(gts)
    V_min_sample, V_max_sample = get_variances(samples)
    return 1.0 - (abs(V_min_gt - V_min_sample)
                  + abs(V_max_gt - V_max_sample)) / 2.0


def calculate_ci_score(samples, gts):
    """Return ``(ci, Sc, Dmax, Da)`` for one image.

    ``Sc`` is the combined sensitivity, ``Dmax`` the mean best Dice and
    ``Da`` the diversity agreement. They are combined as
    ``3 * Sc * Dmax * Da / (Sc + Dmax + Da)``; a small epsilon in the
    denominator guards against division by zero.
    """
    Sc = calculate_combined_sensitivity(samples, gts)
    Dmax = calculate_d_max(samples, gts)
    Da = calculate_diversity_agreement(samples, gts)
    denominator = Sc + Dmax + Da
    ci = (3 * Sc * Dmax * Da) / (denominator + 1e-8)
    return ci, Sc, Dmax, Da


def calculate_ged(samples, gts):
    """Return the generalized energy distance between samples and gts.

    Uses ``1 - IoU`` as the distance. The cross term averages over every
    sample/ground-truth pair; the within-set terms average over unordered
    pairs of distinct masks and are 0 for a set with a single mask.
    Returns ``np.nan`` when either list is empty.
    """
    def dist_fn(x, y):
        return 1.0 - iou_score(x, y)

    if not samples or not gts:
        return np.nan
    d_st = np.mean([dist_fn(s, g) for s in samples for g in gts])
    d_ss = (np.mean([dist_fn(samples[i], samples[j])
                     for i in range(len(samples))
                     for j in range(i + 1, len(samples))])
            if len(samples) > 1 else 0)
    d_gg = (np.mean([dist_fn(gts[i], gts[j])
                     for i in range(len(gts))
                     for j in range(i + 1, len(gts))])
            if len(gts) > 1 else 0)
    return 2 * d_st - d_ss - d_gg


def load_mask(path):
    """Load an image file as a boolean mask.

    The image is converted to 8-bit greyscale and thresholded at > 127.
    """
    with Image.open(path) as img:
        return np.array(img.convert("L")) > 127


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Evaluate saved sample masks against the test split and write a CSV.

    For every test item, sample files named ``<image_id>_sample_*.png`` are
    looked up in ``--samples_dir``. Items without any such file are skipped
    (and counted). Per-image metrics plus a final ``AVERAGE`` row are written
    to ``--results_file``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate segmentation samples with CI/GED metrics.")
    parser.add_argument("--samples_dir", type=str, required=True,
                        help="Directory with saved sample masks.")
    parser.add_argument("--gt_data_dir", type=str, required=True,
                        help="Root of multiannotator dataset.")
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--results_file", type=str,
                        default="evaluation_results.csv")
    args = parser.parse_args()

    # Load ground truth test set
    gt_dataset = MultiAnnotatorDataset(
        data_root=args.gt_data_dir, split="test", image_size=args.image_size)

    # Escape glob metacharacters so that only the trailing "*" is a wildcard;
    # otherwise a directory or id containing "[", "]", "*" or "?" would
    # silently match nothing.
    samples_dir_pattern = glob.escape(args.samples_dir)

    all_results = []
    num_skipped = 0
    print(f"Evaluating samples from {args.samples_dir}...")

    for i in tqdm(range(len(gt_dataset)), desc="Evaluating"):
        _, gts_tensor, image_id = gt_dataset[i]

        # Find sample files for this image
        id_pattern = glob.escape(f"{image_id}")
        sample_paths = sorted(glob.glob(
            os.path.join(samples_dir_pattern, f"{id_pattern}_sample_*.png")))
        if not sample_paths:
            num_skipped += 1
            continue

        samples = [load_mask(p) for p in sample_paths]
        gts = [gt.numpy() for gt in gts_tensor]

        ci_score, sc, dmax, da = calculate_ci_score(samples, gts)
        ged = calculate_ged(samples, gts)

        all_dice = [dice_coefficient(s, g) for s in samples for g in gts]
        all_iou = [iou_score(s, g) for s in samples for g in gts]
        # Pairs where either mask is empty yield NaN and are left out of the
        # Hausdorff average.
        all_hd = [d for d in
                  [hausdorff_distance(s, g) for s in samples for g in gts]
                  if not np.isnan(d)]

        all_results.append({
            "image_id": image_id,
            "num_samples": len(samples),
            "num_gts": len(gts),
            "CI_Score": ci_score,
            "Combined_Sensitivity": sc,
            "D_max": dmax,
            "Diversity_Agreement": da,
            "GED": ged,
            "Avg_Dice": np.mean(all_dice) if all_dice else 0,
            "Avg_IoU": np.mean(all_iou) if all_iou else 0,
            "Avg_Hausdorff": np.mean(all_hd) if all_hd else np.nan,
        })

    if num_skipped:
        # Make it visible that the averages cover only part of the test set.
        print(f"Skipped {num_skipped} test images with no matching "
              f"sample files.")

    if not all_results:
        print("No results generated. Check paths and filenames.")
        return

    df = pd.DataFrame(all_results)
    column_order = [
        "image_id", "num_samples", "num_gts",
        "CI_Score", "D_max", "Combined_Sensitivity", "Diversity_Agreement",
        "GED", "Avg_Dice", "Avg_IoU", "Avg_Hausdorff"
    ]
    df = df[column_order]

    avg_row = df.mean(numeric_only=True).to_frame().T
    avg_row["image_id"] = "AVERAGE"
    df_final = pd.concat([df, avg_row], ignore_index=True)

    # Create the output directory if needed so a finished evaluation is not
    # lost to a missing folder. dirname is "" for a bare filename.
    results_dir = os.path.dirname(args.results_file)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    df_final.to_csv(args.results_file, index=False, float_format="%.4f")

    print(f"\nEvaluation complete. Results for {len(all_results)} images "
          f"saved to {args.results_file}")
    print("\n--- Averages ---")
    print(avg_row[column_order[1:]].to_string(index=False,
                                              float_format="%.4f"))


if __name__ == "__main__":
    main()