"""
Evaluation: Hungarian matching, per-class metrics, LOOCV runner.
Uses scipy linear_sum_assignment for optimal bipartite matching between
predictions and ground truth with class-specific match radii.
"""
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from typing import Dict, List, Optional, Tuple
def compute_f1(tp: int, fp: int, fn: int, eps: float = 1e-6) -> Tuple[float, float, float]:
    """Return (f1, precision, recall) derived from TP/FP/FN counts.

    ``eps`` pads each denominator so empty count pairs yield 0 instead of
    raising ZeroDivisionError; values are asymptotically exact.
    """
    predicted_total = tp + fp
    actual_total = tp + fn
    precision = tp / (predicted_total + eps)
    recall = tp / (actual_total + eps)
    f1 = (2 * precision * recall) / (precision + recall + eps)
    return f1, precision, recall
def match_detections_to_gt(
    detections: List[dict],
    gt_coords_6nm: np.ndarray,
    gt_coords_12nm: np.ndarray,
    match_radii: Optional[Dict[str, float]] = None,
) -> Dict[str, dict]:
    """
    Hungarian matching between predictions and ground truth.
    A detection matches GT only if:
        1. Euclidean distance < match_radius[class]
        2. Predicted class == GT class
    Args:
        detections: list of {'x', 'y', 'class', 'conf'}
        gt_coords_6nm: (N, 2) array of (x, y) GT for 6nm
        gt_coords_12nm: (M, 2) array of (x, y) GT for 12nm
        match_radii: per-class match radius in pixels
    Returns:
        Dict with per-class and overall TP/FP/FN/F1/precision/recall.
    """
    radii = {"6nm": 9.0, "12nm": 15.0} if match_radii is None else match_radii
    gt_lookup = {"6nm": gt_coords_6nm, "12nm": gt_coords_12nm}
    results: Dict[str, dict] = {}
    # Running totals across both classes for the "overall" entry.
    totals = {"tp": 0, "fp": 0, "fn": 0}

    for cls in ("6nm", "12nm"):
        preds = [d for d in detections if d["class"] == cls]
        gt = gt_lookup[cls]
        radius = radii[cls]

        # Degenerate cases: one or both sides empty, no assignment needed.
        if not preds and len(gt) == 0:
            # Nothing predicted, nothing annotated: vacuously perfect.
            results[cls] = {
                "tp": 0, "fp": 0, "fn": 0,
                "f1": 1.0, "precision": 1.0, "recall": 1.0,
            }
            continue
        if not preds:
            # Every GT point was missed.
            results[cls] = {
                "tp": 0, "fp": 0, "fn": len(gt),
                "f1": 0.0, "precision": 0.0, "recall": 0.0,
            }
            totals["fn"] += len(gt)
            continue
        if len(gt) == 0:
            # Every prediction is spurious.
            results[cls] = {
                "tp": 0, "fp": len(preds), "fn": 0,
                "f1": 0.0, "precision": 0.0, "recall": 0.0,
            }
            totals["fp"] += len(preds)
            continue

        # Pairwise distances between predictions (rows) and GT (cols).
        pred_xy = np.array([[d["x"], d["y"]] for d in preds])
        dist = cdist(pred_xy, gt)
        # Out-of-radius pairs get a prohibitive cost so the assignment
        # avoids them whenever a feasible pairing exists.
        masked = np.where(dist > radius, 1e6, dist)
        rows, cols = linear_sum_assignment(masked)
        # Only assignments that stayed within the radius count as TP.
        tp = int(np.count_nonzero(masked[rows, cols] <= radius))
        fp = len(preds) - tp
        fn = len(gt) - tp
        f1, precision, recall = compute_f1(tp, fp, fn)
        results[cls] = {
            "tp": tp, "fp": fp, "fn": fn,
            "f1": f1, "precision": precision, "recall": recall,
        }
        totals["tp"] += tp
        totals["fp"] += fp
        totals["fn"] += fn

    # Micro-averaged metrics over both classes.
    f1_all, prec_all, rec_all = compute_f1(totals["tp"], totals["fp"], totals["fn"])
    results["overall"] = {
        "tp": totals["tp"], "fp": totals["fp"], "fn": totals["fn"],
        "f1": f1_all, "precision": prec_all, "recall": rec_all,
    }
    # Macro mean F1 over classes that actually have GT (tp + fn > 0).
    per_class_f1 = [
        results[c]["f1"] for c in ("6nm", "12nm")
        if results[c]["tp"] + results[c]["fn"] > 0
    ]
    results["mean_f1"] = np.mean(per_class_f1) if per_class_f1 else 0.0
    return results
def evaluate_fold(
    detections: List[dict],
    gt_annotations: Dict[str, np.ndarray],
    match_radii: Optional[Dict[str, float]] = None,
    has_6nm: bool = True,
) -> Dict[str, dict]:
    """
    Evaluate detections for a single LOOCV fold.
    Args:
        detections: model predictions
        gt_annotations: {'6nm': Nx2, '12nm': Mx2}
        match_radii: per-class match radii
        has_6nm: whether this fold has 6nm GT (False for S7, S15)
    Returns:
        Evaluation metrics dict.
    """
    # Missing classes fall back to an empty (0, 2) coordinate array.
    metrics = match_detections_to_gt(
        detections,
        gt_annotations.get("6nm", np.empty((0, 2))),
        gt_annotations.get("12nm", np.empty((0, 2))),
        match_radii,
    )
    if has_6nm:
        return metrics
    # Flag the 6nm entry so downstream reports know the GT was absent.
    metrics["6nm"]["note"] = "N/A (missing annotations)"
    return metrics
def compute_average_precision(
    detections: List[dict],
    gt_coords: np.ndarray,
    match_radius: float,
) -> float:
    """
    Compute Average Precision (AP) for a single class.

    Detections are sorted by descending confidence and greedily matched to
    the nearest *unmatched* ground-truth point within ``match_radius``
    (Euclidean distance). AP is the area under the raw precision-recall
    curve (rectangle rule over recall increments).

    Args:
        detections: list of {'x', 'y', 'conf', ...} predictions
        gt_coords: (N, 2) array of (x, y) ground-truth coordinates
        match_radius: maximum distance for a detection to match a GT point

    Returns:
        AP in [0, 1]. By convention: 1.0 when both GT and detections are
        empty, 0.0 when there is GT but no detections (or vice versa).
    """
    if len(gt_coords) == 0:
        return 0.0 if detections else 1.0
    if not detections:
        return 0.0
    # Process highest-confidence detections first (PASCAL VOC ordering).
    sorted_dets = sorted(detections, key=lambda d: d["conf"], reverse=True)
    gt = np.asarray(gt_coords, dtype=float)
    matched = np.zeros(len(gt), dtype=bool)
    tp_flags = []
    for det in sorted_dets:
        dists = np.hypot(gt[:, 0] - det["x"], gt[:, 1] - det["y"])
        # Bug fix: exclude already-claimed GT *before* taking the argmin.
        # Previously a detection whose single nearest GT was taken became a
        # FP even when another unmatched GT lay within the match radius.
        dists[matched] = np.inf
        nearest = int(np.argmin(dists))
        if dists[nearest] <= match_radius:
            tp_flags.append(1)
            matched[nearest] = True
        else:
            tp_flags.append(0)
    tp_cum = np.cumsum(tp_flags)
    # Precision at rank k uses k total predictions; recall over all GT.
    precision = tp_cum / np.arange(1, len(tp_flags) + 1)
    recall = tp_cum / len(gt)
    # Rectangle-rule area: sum precision[i] * (recall[i] - recall[i-1]),
    # with recall[-1] treated as 0 for the first point.
    prev_recall = np.concatenate(([0.0], recall[:-1]))
    return float(np.sum(precision * (recall - prev_recall)))