RuthvikBandari commited on
Commit
a2c35ef
·
verified ·
1 Parent(s): 15d8b7a

Upload src/evaluation/metrics.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/evaluation/metrics.py +242 -0
src/evaluation/metrics.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DiaFoot.AI v2 — Segmentation Metrics.
2
+
3
+ Phase 4, Commit 20: Dice, IoU, HD95, NSD, ASSD + clinical metrics.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+
12
+
13
+ def dice_score(pred: np.ndarray, target: np.ndarray, smooth: float = 1e-6) -> float:
14
+ """Compute Dice coefficient.
15
+
16
+ Args:
17
+ pred: Binary prediction mask (H, W).
18
+ target: Binary ground truth mask (H, W).
19
+ smooth: Smoothing to avoid division by zero.
20
+
21
+ Returns:
22
+ Dice score between 0 and 1.
23
+ """
24
+ pred_flat = pred.astype(bool).flatten()
25
+ target_flat = target.astype(bool).flatten()
26
+ intersection = (pred_flat & target_flat).sum()
27
+ return float((2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth))
28
+
29
+
30
+ def iou_score(pred: np.ndarray, target: np.ndarray, smooth: float = 1e-6) -> float:
31
+ """Compute Intersection over Union (Jaccard Index).
32
+
33
+ Args:
34
+ pred: Binary prediction mask (H, W).
35
+ target: Binary ground truth mask (H, W).
36
+ smooth: Smoothing factor.
37
+
38
+ Returns:
39
+ IoU score between 0 and 1.
40
+ """
41
+ pred_flat = pred.astype(bool).flatten()
42
+ target_flat = target.astype(bool).flatten()
43
+ intersection = (pred_flat & target_flat).sum()
44
+ union = (pred_flat | target_flat).sum()
45
+ return float((intersection + smooth) / (union + smooth))
46
+
47
+
48
+ def hausdorff_distance_95(pred: np.ndarray, target: np.ndarray) -> float:
49
+ """Compute 95th percentile Hausdorff Distance.
50
+
51
+ Measures the boundary quality of segmentation.
52
+
53
+ Args:
54
+ pred: Binary prediction mask (H, W).
55
+ target: Binary ground truth mask (H, W).
56
+
57
+ Returns:
58
+ HD95 in pixels. Lower is better.
59
+ """
60
+ from scipy.ndimage import distance_transform_edt
61
+
62
+ pred_bool = pred.astype(bool)
63
+ target_bool = target.astype(bool)
64
+
65
+ # Handle edge cases
66
+ if not pred_bool.any() and not target_bool.any():
67
+ return 0.0
68
+ if not pred_bool.any() or not target_bool.any():
69
+ return float(max(pred.shape))
70
+
71
+ # Distance from pred boundary to nearest target boundary
72
+ pred_boundary = pred_bool ^ _erode(pred_bool)
73
+ target_boundary = target_bool ^ _erode(target_bool)
74
+
75
+ if not pred_boundary.any() or not target_boundary.any():
76
+ return float(max(pred.shape))
77
+
78
+ dt_target = distance_transform_edt(~target_boundary)
79
+ dt_pred = distance_transform_edt(~pred_boundary)
80
+
81
+ dist_pred_to_target = dt_target[pred_boundary]
82
+ dist_target_to_pred = dt_pred[target_boundary]
83
+
84
+ all_distances = np.concatenate([dist_pred_to_target, dist_target_to_pred])
85
+ return float(np.percentile(all_distances, 95))
86
+
87
+
88
+ def _erode(mask: np.ndarray) -> np.ndarray:
89
+ """Simple erosion by 1 pixel."""
90
+ from scipy.ndimage import binary_erosion
91
+
92
+ return binary_erosion(mask, iterations=1)
93
+
94
+
95
+ def surface_dice(
96
+ pred: np.ndarray,
97
+ target: np.ndarray,
98
+ tolerance_mm: float = 2.0,
99
+ pixel_spacing: float = 1.0,
100
+ ) -> float:
101
+ """Compute Normalized Surface Dice (NSD).
102
+
103
+ Measures what fraction of boundary points are within tolerance distance.
104
+
105
+ Args:
106
+ pred: Binary prediction mask.
107
+ target: Binary ground truth mask.
108
+ tolerance_mm: Tolerance in mm.
109
+ pixel_spacing: mm per pixel.
110
+
111
+ Returns:
112
+ NSD score between 0 and 1.
113
+ """
114
+ from scipy.ndimage import distance_transform_edt
115
+
116
+ tolerance_px = tolerance_mm / pixel_spacing
117
+ pred_bool = pred.astype(bool)
118
+ target_bool = target.astype(bool)
119
+
120
+ if not pred_bool.any() and not target_bool.any():
121
+ return 1.0
122
+ if not pred_bool.any() or not target_bool.any():
123
+ return 0.0
124
+
125
+ pred_boundary = pred_bool ^ _erode(pred_bool)
126
+ target_boundary = target_bool ^ _erode(target_bool)
127
+
128
+ if not pred_boundary.any() or not target_boundary.any():
129
+ return 0.0
130
+
131
+ dt_target = distance_transform_edt(~target_boundary)
132
+ dt_pred = distance_transform_edt(~pred_boundary)
133
+
134
+ pred_within = (dt_target[pred_boundary] <= tolerance_px).sum()
135
+ target_within = (dt_pred[target_boundary] <= tolerance_px).sum()
136
+
137
+ total_boundary = pred_boundary.sum() + target_boundary.sum()
138
+ return float((pred_within + target_within) / max(1, total_boundary))
139
+
140
+
141
+ def wound_area_mm2(
142
+ mask: np.ndarray,
143
+ pixel_spacing_mm: float = 0.5,
144
+ ) -> float:
145
+ """Estimate wound area in mm squared.
146
+
147
+ Args:
148
+ mask: Binary wound mask.
149
+ pixel_spacing_mm: Physical size of one pixel in mm.
150
+
151
+ Returns:
152
+ Wound area in mm squared.
153
+ """
154
+ wound_pixels = mask.astype(bool).sum()
155
+ return float(wound_pixels * pixel_spacing_mm * pixel_spacing_mm)
156
+
157
+
158
+ def compute_segmentation_metrics(
159
+ pred: np.ndarray,
160
+ target: np.ndarray,
161
+ pixel_spacing_mm: float = 0.5,
162
+ ) -> dict[str, float]:
163
+ """Compute all segmentation metrics for a single image.
164
+
165
+ Args:
166
+ pred: Binary prediction (H, W).
167
+ target: Binary ground truth (H, W).
168
+ pixel_spacing_mm: Physical pixel size.
169
+
170
+ Returns:
171
+ Dict with all metrics.
172
+ """
173
+ metrics: dict[str, float] = {
174
+ "dice": dice_score(pred, target),
175
+ "iou": iou_score(pred, target),
176
+ }
177
+
178
+ # Only compute boundary metrics if both masks have content
179
+ if pred.astype(bool).any() and target.astype(bool).any():
180
+ metrics["hd95"] = hausdorff_distance_95(pred, target)
181
+ metrics["nsd_2mm"] = surface_dice(
182
+ pred, target, tolerance_mm=2.0, pixel_spacing=pixel_spacing_mm
183
+ )
184
+ metrics["nsd_5mm"] = surface_dice(
185
+ pred, target, tolerance_mm=5.0, pixel_spacing=pixel_spacing_mm
186
+ )
187
+ else:
188
+ metrics["hd95"] = 0.0 if not target.astype(bool).any() else float(max(pred.shape))
189
+ metrics["nsd_2mm"] = 1.0 if not target.astype(bool).any() else 0.0
190
+ metrics["nsd_5mm"] = 1.0 if not target.astype(bool).any() else 0.0
191
+
192
+ # Clinical metrics
193
+ metrics["wound_area_mm2"] = wound_area_mm2(pred, pixel_spacing_mm)
194
+ metrics["wound_area_gt_mm2"] = wound_area_mm2(target, pixel_spacing_mm)
195
+
196
+ return metrics
197
+
198
+
199
+ def aggregate_metrics(
200
+ all_metrics: list[dict[str, float]],
201
+ ) -> dict[str, Any]:
202
+ """Aggregate per-image metrics into summary statistics.
203
+
204
+ Args:
205
+ all_metrics: List of per-image metric dicts.
206
+
207
+ Returns:
208
+ Dict with mean, std, median for each metric.
209
+ """
210
+ if not all_metrics:
211
+ return {}
212
+
213
+ keys = all_metrics[0].keys()
214
+ summary: dict[str, Any] = {}
215
+
216
+ for key in keys:
217
+ values = [m[key] for m in all_metrics if key in m]
218
+ if values:
219
+ summary[key] = {
220
+ "mean": float(np.mean(values)),
221
+ "std": float(np.std(values)),
222
+ "median": float(np.median(values)),
223
+ "min": float(np.min(values)),
224
+ "max": float(np.max(values)),
225
+ }
226
+
227
+ return summary
228
+
229
+
230
+ def print_segmentation_report(summary: dict[str, Any]) -> None:
231
+ """Print formatted segmentation results."""
232
+ print(f"\n{'=' * 60}") # noqa: T201
233
+ print("Segmentation Results") # noqa: T201
234
+ print(f"{'=' * 60}") # noqa: T201
235
+ for metric, stats in summary.items():
236
+ if isinstance(stats, dict) and "mean" in stats:
237
+ print( # noqa: T201
238
+ f" {metric:20s}: {stats['mean']:.4f} "
239
+ f"(+/- {stats['std']:.4f}) "
240
+ f"[{stats['min']:.4f}, {stats['max']:.4f}]"
241
+ )
242
+ print(f"{'=' * 60}\n") # noqa: T201