AnikS22 committed on
Commit
e52f2ac
·
verified ·
1 Parent(s): 6dd4c34

Upload src/evaluate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/evaluate.py +203 -0
src/evaluate.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation: Hungarian matching, per-class metrics, LOOCV runner.
3
+
4
+ Uses scipy linear_sum_assignment for optimal bipartite matching between
5
+ predictions and ground truth with class-specific match radii.
6
+ """
7
+
8
+ import numpy as np
9
+ from scipy.optimize import linear_sum_assignment
10
+ from scipy.spatial.distance import cdist
11
+ from typing import Dict, List, Optional, Tuple
12
+
13
+
14
def compute_f1(tp: int, fp: int, fn: int, eps: float = 1e-6) -> Tuple[float, float, float]:
    """Derive F1, precision and recall from raw detection counts.

    Each denominator is padded with ``eps`` so that all-zero counts yield
    0.0 instead of raising ZeroDivisionError. As a side effect the scores
    are biased very slightly downward (a perfect result is marginally
    below 1.0 with the default ``eps``).

    Args:
        tp: number of true positives.
        fp: number of false positives.
        fn: number of false negatives.
        eps: smoothing term added to every denominator.

    Returns:
        Tuple ``(f1, precision, recall)`` -- note that F1 comes first.
    """
    predicted_total = tp + fp
    actual_total = tp + fn

    precision = tp / (predicted_total + eps)
    recall = tp / (actual_total + eps)

    # Harmonic mean of precision and recall, smoothed the same way.
    harmonic_denominator = precision + recall + eps
    f1 = 2 * precision * recall / harmonic_denominator

    return f1, precision, recall
20
+
21
+
22
def match_detections_to_gt(
    detections: List[dict],
    gt_coords_6nm: np.ndarray,
    gt_coords_12nm: np.ndarray,
    match_radii: Optional[Dict[str, float]] = None,
) -> Dict[str, dict]:
    """
    Hungarian matching between predictions and ground truth.

    Matching is done independently per class, so a detection can only be
    paired with ground truth of its own predicted class. A pairing counts
    as a true positive only if the Euclidean distance is <= the match
    radius of that class (the boundary is inclusive).

    Args:
        detections: list of {'x', 'y', 'class', 'conf'}; only entries whose
            'class' is "6nm" or "12nm" are scored, others are ignored.
        gt_coords_6nm: (N, 2) array of (x, y) GT for 6nm
        gt_coords_12nm: (M, 2) array of (x, y) GT for 12nm
        match_radii: per-class match radius in pixels. May be partial:
            classes that are not listed fall back to the defaults
            (6nm: 9.0, 12nm: 15.0).

    Returns:
        Dict with keys "6nm", "12nm" and "overall", each holding
        tp/fp/fn/f1/precision/recall, plus "mean_f1": the mean F1 over the
        classes that have at least one GT point (0.0 if none has).
    """
    # Overlay caller-supplied radii on the defaults so that a partial
    # mapping does not raise KeyError; the caller's dict is not mutated.
    radii = {"6nm": 9.0, "12nm": 15.0}
    if match_radii is not None:
        radii.update(match_radii)

    # Finite sentinel for forbidden pairings: linear_sum_assignment needs a
    # feasible matrix, so an out-of-radius pair is made expensive, not inf.
    forbidden_cost = 1e6

    class_names = ("6nm", "12nm")
    gt_by_class = {"6nm": gt_coords_6nm, "12nm": gt_coords_12nm}
    results = {}

    total_tp = 0
    total_fp = 0
    total_fn = 0

    for cls in class_names:
        cls_dets = [d for d in detections if d["class"] == cls]
        gt = gt_by_class[cls]
        radius = radii[cls]
        n_det = len(cls_dets)
        n_gt = len(gt)

        if n_det == 0 or n_gt == 0:
            # Nothing can be matched. Empty-vs-empty is scored as perfect;
            # otherwise every detection is a FP / every GT point a FN.
            score = 1.0 if (n_det == 0 and n_gt == 0) else 0.0
            results[cls] = {
                "tp": 0, "fp": n_det, "fn": n_gt,
                "f1": score, "precision": score, "recall": score,
            }
            total_fp += n_det
            total_fn += n_gt
            continue

        # Pairwise distances: rows are detections, columns are GT points.
        pred_coords = np.array([[d["x"], d["y"]] for d in cls_dets])
        cost = cdist(pred_coords, gt)

        # Pairs beyond the radius get the sentinel cost so the optimal
        # assignment prefers any in-radius pairing over them.
        within_radius = cost <= radius
        assignment_cost = np.where(within_radius, cost, forbidden_cost)

        # Hungarian matching
        row_ind, col_ind = linear_sum_assignment(assignment_cost)

        # The solver assigns min(n_det, n_gt) pairs regardless of cost;
        # only the assigned pairs that lie within the radius are real hits.
        tp = int(np.count_nonzero(within_radius[row_ind, col_ind]))
        fp = n_det - tp
        fn = n_gt - tp

        f1, prec, rec = compute_f1(tp, fp, fn)

        results[cls] = {
            "tp": tp, "fp": fp, "fn": fn,
            "f1": f1, "precision": prec, "recall": rec,
        }

        total_tp += tp
        total_fp += fp
        total_fn += fn

    # Overall (micro-averaged over the pooled counts)
    f1_overall, prec_overall, rec_overall = compute_f1(total_tp, total_fp, total_fn)
    results["overall"] = {
        "tp": total_tp, "fp": total_fp, "fn": total_fn,
        "f1": f1_overall, "precision": prec_overall, "recall": rec_overall,
    }

    # Mean F1 across classes; classes without any GT are left out so a
    # class that is absent from the annotations does not drag the mean.
    class_f1s = [
        results[c]["f1"]
        for c in class_names
        if results[c]["fn"] + results[c]["tp"] > 0
    ]
    results["mean_f1"] = np.mean(class_f1s) if class_f1s else 0.0

    return results
124
+
125
+
126
def evaluate_fold(
    detections: List[dict],
    gt_annotations: Dict[str, np.ndarray],
    match_radii: Optional[Dict[str, float]] = None,
    has_6nm: bool = True,
) -> Dict[str, dict]:
    """
    Score the detections of one LOOCV fold against its ground truth.

    Args:
        detections: model predictions, passed through unchanged to
            ``match_detections_to_gt``.
        gt_annotations: mapping with optional keys '6nm' and '12nm', each an
            Nx2 coordinate array; a missing key is treated as "no points".
        match_radii: per-class match radii, forwarded as-is.
        has_6nm: set to False when the fold has no 6nm GT (the original
            note names S7 and S15 as such folds).

    Returns:
        The metrics dict produced by ``match_detections_to_gt``. When
        ``has_6nm`` is False its "6nm" entry additionally carries a "note"
        marking the numbers as not applicable.
    """
    gt_6nm, gt_12nm = (
        gt_annotations.get(cls, np.empty((0, 2))) for cls in ("6nm", "12nm")
    )

    metrics = match_detections_to_gt(detections, gt_6nm, gt_12nm, match_radii)

    if has_6nm:
        return metrics

    # The 6nm numbers are still computed, but flagged so that downstream
    # reporting can tell they come from a fold without 6nm annotations.
    metrics["6nm"]["note"] = "N/A (missing annotations)"
    return metrics
153
+
154
+
155
def compute_average_precision(
    detections: List[dict],
    gt_coords: np.ndarray,
    match_radius: float,
) -> float:
    """
    Compute Average Precision (AP) for a single class.

    Detections are ranked by descending 'conf'. Each one is compared with
    its single nearest GT point: it is a true positive when that point lies
    within ``match_radius`` (inclusive) and has not already been claimed by
    a higher-ranked detection; otherwise it is a false positive. AP is the
    sum over ranks of precision-at-rank times the recall gained at that
    rank. No interpolation or monotone precision envelope is applied.

    Args:
        detections: list of dicts providing 'x', 'y' and 'conf'.
        gt_coords: array of GT (x, y) rows for the class.
        match_radius: maximum distance for a detection to count as a hit.

    Returns:
        The AP value. With no GT points the result is 1.0 when there are
        no detections either and 0.0 otherwise.
    """
    n_gt = len(gt_coords)
    if n_gt == 0:
        return 1.0 if not detections else 0.0

    # Highest confidence first; sorted() is stable, so ties keep input order.
    ranked = sorted(detections, key=lambda d: d["conf"], reverse=True)

    claimed = set()
    hits = []
    for det in ranked:
        offsets = gt_coords - np.array([det["x"], det["y"]])
        dists = np.sqrt(np.sum(offsets ** 2, axis=1))
        nearest = np.argmin(dists)

        # Only the nearest GT point is considered: if it is already taken,
        # the detection is a false positive even if another point is close.
        is_hit = dists[nearest] <= match_radius and nearest not in claimed
        if is_hit:
            claimed.add(nearest)
        hits.append(1 if is_hit else 0)

    tp_cumsum = np.cumsum(hits)
    fp_cumsum = np.cumsum([1 - h for h in hits])

    precision = tp_cumsum / (tp_cumsum + fp_cumsum)
    recall = tp_cumsum / n_gt

    # Accumulate precision weighted by the recall increment at each rank.
    ap = 0.0
    prev_recall = 0.0
    for prec_at_rank, recall_at_rank in zip(precision, recall):
        ap += prec_at_rank * (recall_at_rank - prev_recall)
        prev_recall = recall_at_rank

    return ap