File size: 6,235 Bytes
457db56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import os
import sys
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import glob
from scipy.spatial.distance import directed_hausdorff

# Add project root to path for custom module import
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset

# --- METRIC IMPLEMENTATIONS ---

def dice_coefficient(pred, target):
    """Dice similarity between two masks.

    Args:
        pred: Predicted mask array; every non-zero element counts as foreground.
        target: Reference mask array, binarised the same way.

    Returns:
        ``(2*|pred & target| + eps) / (|pred| + |target| + eps)`` with
        ``eps = 1e-6``, or exactly ``1.0`` when neither mask has foreground.
    """
    eps = 1e-6
    a = pred.astype(bool)
    b = target.astype(bool)
    # Two empty masks agree perfectly; answer directly instead of via eps.
    if not (np.any(a) or np.any(b)):
        return 1.0
    overlap = np.sum(a & b)
    total = np.sum(a) + np.sum(b)
    return (2. * overlap + eps) / (total + eps)

def iou_score(pred, target):
    """Intersection-over-union (Jaccard index) of two masks.

    Both inputs are cast to bool, so any non-zero element is foreground.
    Returns ``(|A & B| + eps) / (|A | B| + eps)`` with ``eps = 1e-6``, and
    exactly ``1.0`` when both masks are empty.
    """
    eps = 1e-6
    mask_a = pred.astype(bool)
    mask_b = target.astype(bool)
    if not mask_a.any() and not mask_b.any():
        return 1.0
    inter_count = np.sum(np.logical_and(mask_a, mask_b))
    union_count = np.sum(np.logical_or(mask_a, mask_b))
    return (inter_count + eps) / (union_count + eps)

def hausdorff_distance(pred, target):
    """Symmetric Hausdorff distance between the foreground pixels of two masks.

    Foreground is every non-zero element. Its index coordinates come from
    ``np.argwhere``, so the distance is measured in pixel (index) units.
    The result is the larger of the two directed Hausdorff distances.
    Returns ``np.nan`` when either mask has no foreground.
    """
    coords_a = np.argwhere(pred)
    coords_b = np.argwhere(target)
    if coords_a.shape[0] == 0 or coords_b.shape[0] == 0:
        return np.nan
    forward = directed_hausdorff(coords_a, coords_b)[0]
    backward = directed_hausdorff(coords_b, coords_a)[0]
    return max(forward, backward)

def calculate_combined_sensitivity(samples, gts):
    """Sensitivity of the union of all samples against the union of all annotations.

    Every mask is binarised (non-zero -> True) and OR-ed across its list.
    The result is the fraction of combined ground-truth foreground pixels
    that the combined sample foreground also covers. Returns 1.0 when the
    combined ground truth has no foreground.
    """
    union_pred = np.logical_or.reduce([m.astype(bool) for m in samples])
    union_true = np.logical_or.reduce([m.astype(bool) for m in gts])
    if not np.any(union_true):
        return 1.0
    hits = np.sum(union_pred & union_true)
    return hits / np.sum(union_true)

def calculate_d_max(samples, gts):
    """Average, over annotations, of the best Dice any sample achieves against it.

    Each ground-truth mask is scored against every sample with
    ``dice_coefficient`` and keeps its maximum (0.0 when there are no
    samples). Those maxima are then averaged. Returns 0.0 when ``gts`` is empty.
    """
    best_per_gt = []
    for gt in gts:
        if samples:
            best_per_gt.append(np.max([dice_coefficient(s, gt) for s in samples]))
        else:
            best_per_gt.append(0.0)
    if not best_per_gt:
        return 0.0
    return np.mean(best_per_gt)

def calculate_diversity_agreement(samples, gts):
    """Agreement between the diversity of the samples and of the annotations.

    For each set, diversity is summarised by the minimum and maximum pairwise
    distance ``1 - Dice`` over all unordered pairs. A set with fewer than two
    masks has (0, 0). Returns ``1 - (|dmin_gt - dmin_s| + |dmax_gt - dmax_s|) / 2``.
    """
    def _spread(masks):
        # (min, max) pairwise 1-Dice distance; (0, 0) for fewer than two masks.
        count = len(masks)
        if count < 2:
            return 0, 0
        distances = []
        for i in range(count):
            for j in range(i + 1, count):
                distances.append(1.0 - dice_coefficient(masks[i], masks[j]))
        return np.min(distances), np.max(distances)

    gt_lo, gt_hi = _spread(gts)
    smp_lo, smp_hi = _spread(samples)
    delta_min = abs(gt_lo - smp_lo)
    delta_max = abs(gt_hi - smp_hi)
    return 1.0 - (delta_min + delta_max) / 2.0

def calculate_ci_score(samples, gts):
    """Collective Insight (CI) score together with its three components.

    Returns:
        Tuple ``(ci, Sc, Dmax, Da)``:
        - ``Sc`` is the combined sensitivity.
        - ``Dmax`` is the maximum-Dice matching.
        - ``Da`` is the diversity agreement.
        - ``ci = 3*Sc*Dmax*Da / (Sc + Dmax + Da + 1e-8)``, where the 1e-8
          term keeps the division finite when all three are zero.

    NOTE(review): 3abc/(a+b+c) is not the harmonic mean of a, b, c (that
    would be 3abc/(ab+bc+ca)). Confirm this matches the intended CI
    definition before comparing against published numbers.
    """
    sensitivity = calculate_combined_sensitivity(samples, gts)
    d_max = calculate_d_max(samples, gts)
    agreement = calculate_diversity_agreement(samples, gts)
    numerator = 3 * sensitivity * d_max * agreement
    total = sensitivity + d_max + agreement
    return numerator / (total + 1e-8), sensitivity, d_max, agreement

def calculate_ged(samples, gts):
    """Generalized Energy Distance between the sample set and the annotation set.

    Uses the distance ``d(x, y) = 1 - iou_score(x, y)`` and returns
    ``2*E[d(S, Y)] - E[d(S, S')] - E[d(Y, Y')]``.
    - The cross term averages over every (sample, annotation) pair.
    - The within-set terms average over unordered pairs ``i < j`` only, so
      self-pairs are excluded.
    - A within-set term is 0 for a set with fewer than two masks.
    """
    def distance(x, y):
        return 1.0 - iou_score(x, y)

    def mean_within(masks):
        # Mean distance over unordered distinct pairs; 0 for a single mask or none.
        if len(masks) <= 1:
            return 0
        return np.mean([distance(masks[i], masks[j])
                        for i in range(len(masks))
                        for j in range(i + 1, len(masks))])

    cross = np.mean([distance(s, g) for s in samples for g in gts])
    return 2 * cross - mean_within(samples) - mean_within(gts)

def load_mask(path):
    """Read an image file and binarise it.

    The image is converted to 8-bit grayscale ("L"). Pixels brighter than
    127 become True and all others become False.
    """
    with Image.open(path) as image:
        grayscale = np.array(image.convert("L"))
    return grayscale > 127

def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Evaluate ambiguous segmentation samples.")
    parser.add_argument("--samples_dir", type=str, required=True, help="Directory with saved sample masks.")
    parser.add_argument("--gt_data_dir", type=str, required=True, help="Path to the root of the LIDC data directory.")
    parser.add_argument("--dataset_type", type=str, default="lidc", choices=["lidc", "multiannotator"],
                        help="Ground-truth dataset layout.")
    parser.add_argument("--split_strategy", type=str, default="all_annotations",
                        help="Split strategy for multiannotator dataset.")
    parser.add_argument("--image_size", type=int, default=128, help="Image size used during training/sampling.")
    parser.add_argument("--results_file", type=str, default="evaluation_results.csv", help="Path to save the output CSV file.")
    return parser.parse_args()


def _find_sample_paths(samples_dir, image_id):
    """Return the sorted paths of ``{image_id}_sample_*.png`` files inside *samples_dir*."""
    # Escape both parts so glob metacharacters ('[', '*', '?') in the
    # directory or the id are matched literally instead of as wildcards.
    pattern = os.path.join(glob.escape(samples_dir), f"{glob.escape(image_id)}_sample_*.png")
    # Sorting only makes the order deterministic. Every metric is
    # symmetric in the samples, so the order does not affect results.
    return sorted(glob.glob(pattern))


def _evaluate_image(image_id, samples, gts):
    """Compute every per-image metric for one image and return it as a result row."""
    ci_score, sc, dmax, da = calculate_ci_score(samples, gts)
    ged = calculate_ged(samples, gts)

    all_dice = [dice_coefficient(s, g) for s in samples for g in gts]
    all_iou = [iou_score(s, g) for s in samples for g in gts]
    # hausdorff_distance returns NaN when either mask is empty; leave those pairs out.
    all_hd = [d for d in (hausdorff_distance(s, g) for s in samples for g in gts)
              if not np.isnan(d)]

    return {
        "image_id": image_id, "CI_Score": ci_score, "Combined_Sensitivity": sc,
        "D_max": dmax, "Diversity_Agreement": da, "GED": ged,
        "Avg_Dice": np.mean(all_dice) if all_dice else 0,
        "Avg_IoU": np.mean(all_iou) if all_iou else 0,
        "Avg_Hausdorff": np.mean(all_hd) if all_hd else np.nan,
    }


def main():
    """Score saved sample masks against the test-split annotations and write a CSV.

    For every image in the ``test`` split of ``CustomLIDCDataset``:
    - Load the matching ``{image_id}_sample_*.png`` masks from ``--samples_dir``.
    - Compute the metrics via ``_evaluate_image``.
    - Skip, with a warning, any image that has no sample files.

    Writes one row per image, plus an ``AVERAGE`` row of column means, to
    ``--results_file``.
    """
    args = _parse_args()

    gt_dataset = CustomLIDCDataset(
        data_root=args.gt_data_dir,
        split="test",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    all_results = []

    print(f"Evaluating {len(gt_dataset)} images...")
    for i in tqdm(range(len(gt_dataset))):
        # Presumably (image, annotation masks, id); only the last two are used here.
        _, gts_tensor, image_id = gt_dataset[i]
        image_id = str(image_id)

        sample_paths = _find_sample_paths(args.samples_dir, image_id)
        if not sample_paths:
            # tqdm.write prints above the progress bar instead of breaking it up.
            tqdm.write(f"Warning: No samples found for {image_id}. Skipping.")
            continue

        samples = [load_mask(p) for p in sample_paths]
        # Drop singleton axes so each GT has the same rank as the 2-D sample
        # masks; directed_hausdorff rejects point sets of different
        # dimensionality. This is a no-op for masks that are already 2-D.
        # NOTE(review): assumes each GT is one mask, possibly with a leading
        # channel axis. Confirm against CustomLIDCDataset.
        gts = [np.squeeze(gt.numpy()) for gt in gts_tensor]

        all_results.append(_evaluate_image(image_id, samples, gts))

    if not all_results:
        print("No results generated. Check paths and filenames.")
        return

    df = pd.DataFrame(all_results)
    # DataFrame.mean skips NaN by default, so an image with an undefined
    # Hausdorff distance does not turn the average into NaN.
    avg_row = df.mean(numeric_only=True).to_frame().T
    avg_row['image_id'] = 'AVERAGE'
    df_final = pd.concat([df, avg_row], ignore_index=True)

    df_final.to_csv(args.results_file, index=False, float_format='%.4f')
    print(f"\nEvaluation complete. Results saved to {args.results_file}")
    print("\n--- Averages ---")
    print(avg_row.to_string(index=False, float_format='%.4f'))

if __name__ == "__main__":
    main()