"""
Post-processing for particle detections: structural-mask filtering,
cross-class NMS, and per-class confidence-threshold sweep.
"""
import numpy as np
from scipy.spatial.distance import cdist
from skimage.morphology import dilation, disk
from typing import Dict, List, Optional
def apply_structural_mask_filter(
    detections: List[dict],
    mask: Optional[np.ndarray],
    margin_px: int = 5,
) -> List[dict]:
    """
    Remove detections that fall outside the (dilated) tissue mask.

    Args:
        detections: list of {'x', 'y', 'class', 'conf'} dicts; only 'x'
            (column index) and 'y' (row index) are read here.
        mask: boolean-like array (H, W) where truthy = tissue region, or None
            to disable filtering.
        margin_px: radius of the disk used to dilate the mask so particles at
            region boundaries survive; values <= 0 disable dilation.
    Returns:
        Detections whose rounded (x, y) pixel lies inside the image and on a
        tissue pixel, in input order. If ``mask`` is None the input list is
        returned unchanged (same object).
    Raises:
        ValueError: if ``mask`` is not 2-D.
    """
    if mask is None:
        return detections
    tissue = np.asarray(mask, dtype=bool)
    if tissue.ndim != 2:
        raise ValueError(f"mask must be 2-D (H, W), got shape {tissue.shape}")
    if margin_px > 0:
        # Grow the tissue region so particles at region boundaries are kept.
        tissue = dilation(tissue, disk(margin_px))
    height, width = tissue.shape

    filtered = []
    for det in detections:
        x, y = float(det["x"]), float(det["y"])
        # NaN/inf coordinates cannot index the mask (round()/int() would raise).
        if not (np.isfinite(x) and np.isfinite(y)):
            continue
        xi, yi = int(round(x)), int(round(y))
        if 0 <= yi < height and 0 <= xi < width and tissue[yi, xi]:
            filtered.append(det)
    return filtered
def cross_class_nms(
    detections: List[dict],
    distance_threshold: float = 8.0,
) -> List[dict]:
    """
    Suppress overlapping detections of *different* classes, keeping the more
    confident one.

    This handles cases where both heads (e.g. 6nm and 12nm) fire on the same
    particle. Detections are visited in descending confidence; each surviving
    detection suppresses every lower-confidence detection of another class
    whose Euclidean distance is strictly below ``distance_threshold``. A
    suppressed detection never suppresses anything itself. Same-class
    detections never suppress each other.

    Args:
        detections: list of dicts with 'x', 'y', 'class' and 'conf' keys.
        distance_threshold: suppression radius in the units of 'x'/'y'.
    Returns:
        A new list of the surviving detection dicts (same objects), ordered by
        descending confidence (stable for ties).
    """
    if len(detections) <= 1:
        # Always hand back a new list, as the general path does.
        return list(detections)

    dets = sorted(detections, key=lambda d: d["conf"], reverse=True)
    coords = np.array([[d["x"], d["y"]] for d in dets], dtype=float)
    # Map class labels to small ints so the per-row comparison is vectorized.
    label_ids: dict = {}
    class_ids = np.array(
        [label_ids.setdefault(d["class"], len(label_ids)) for d in dets]
    )
    keep = np.ones(len(dets), dtype=bool)

    for i in range(len(dets)):
        if not keep[i]:
            continue
        rest = coords[i + 1:]
        dist = np.sqrt(
            (coords[i, 0] - rest[:, 0]) ** 2
            + (coords[i, 1] - rest[:, 1]) ** 2
        )
        suppress = (class_ids[i + 1:] != class_ids[i]) & (dist < distance_threshold)
        # Lower-confidence, other-class neighbours are dropped.
        keep[i + 1:] &= ~suppress

    return [d for d, k in zip(dets, keep) if k]
def sweep_confidence_threshold(
    detections: List[dict],
    gt_coords: Dict[str, np.ndarray],
    match_radii: Dict[str, float],
    start: float = 0.05,
    stop: float = 0.95,
    step: float = 0.01,
    default_threshold: float = 0.3,
) -> Dict[str, float]:
    """
    Sweep confidence thresholds to find the F1-optimal threshold per class.

    Args:
        detections: all detections (before thresholding); each is a dict
            with 'x', 'y', 'class' and 'conf' keys.
        gt_coords: {'6nm': Nx2, '12nm': Mx2} ground-truth coordinates. A
            class missing from the dict is treated as having no ground truth.
        match_radii: per-class match radii in pixels.
        start, stop, step: sweep range, with ``np.arange`` semantics
            (``stop`` itself is not evaluated).
        default_threshold: threshold reported for a class when the sweep
            cannot rank thresholds (no ground truth for that class, or no
            computed F1 compares greater than -1, e.g. if compute_f1 returns
            NaN).
    Returns:
        Dict mapping '6nm' and '12nm' to their best threshold. When several
        thresholds share the best F1, the lowest one is returned.
    """
    from src.evaluate import compute_f1

    # np.arange accumulates float error (e.g. 0.060000000000000005); rounding
    # lets a confidence exactly equal to a grid value pass the >= test.
    thresholds = np.round(np.arange(start, stop, step), 10)
    best_thresholds: Dict[str, float] = {}
    for cls in ["6nm", "12nm"]:
        gt = gt_coords.get(cls)
        if gt is None or len(gt) == 0:
            # Without ground truth tp is always 0, so F1 cannot tell
            # thresholds apart; ranking them anyway would let the lowest
            # threshold win the tie and maximise false positives.
            best_thresholds[cls] = default_threshold
            continue

        # Gather this class's detections once instead of rescanning the full
        # list at every threshold.
        cls_dets = [d for d in detections if d["class"] == cls]
        conf = np.array([d["conf"] for d in cls_dets], dtype=float)
        xy = np.array(
            [[d["x"], d["y"]] for d in cls_dets], dtype=float
        ).reshape(-1, 2)
        radius = match_radii[cls]

        best_f1 = -1.0
        best_thr = default_threshold
        for thr in thresholds:
            # _simple_match returns (0, 0, len(gt)) when no predictions survive.
            tp, fp, fn = _simple_match(xy[conf >= thr], gt, radius)
            f1, _, _ = compute_f1(tp, fp, fn)
            if f1 > best_f1:
                best_f1 = f1
                best_thr = float(thr)
        best_thresholds[cls] = best_thr
    return best_thresholds
def _simple_match(
pred: np.ndarray, gt: np.ndarray, radius: float
) -> tuple:
"""Quick matching for threshold sweep (greedy, not Hungarian)."""
from scipy.spatial.distance import cdist
if len(pred) == 0 or len(gt) == 0:
return 0, len(pred), len(gt)
dists = cdist(pred, gt)
tp = 0
matched_gt = set()
# Greedy: match closest pairs first
for i in range(len(pred)):
min_j = np.argmin(dists[i])
if dists[i, min_j] <= radius and min_j not in matched_gt:
tp += 1
matched_gt.add(min_j)
dists[:, min_j] = np.inf
fp = len(pred) - tp
fn = len(gt) - tp
return tp, fp, fn
|