AnikS22 commited on
Commit
72f2556
·
verified ·
1 Parent(s): 60eb8cf

Upload src/postprocess.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/postprocess.py +157 -0
src/postprocess.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Post-processing: structural mask filtering, cross-class NMS, threshold sweep.
3
+ """
4
+
5
+ import numpy as np
6
+ from scipy.spatial.distance import cdist
7
+ from skimage.morphology import dilation, disk
8
+ from typing import Dict, List, Optional
9
+
10
+
11
+ def apply_structural_mask_filter(
12
+ detections: List[dict],
13
+ mask: np.ndarray,
14
+ margin_px: int = 5,
15
+ ) -> List[dict]:
16
+ """
17
+ Remove detections outside biological tissue regions.
18
+
19
+ Args:
20
+ detections: list of {'x', 'y', 'class', 'conf'}
21
+ mask: boolean array (H, W) where True = tissue region
22
+ margin_px: dilate mask by this many pixels
23
+
24
+ Returns:
25
+ Filtered detection list.
26
+ """
27
+ if mask is None:
28
+ return detections
29
+
30
+ # Dilate mask to allow particles at region boundaries
31
+ tissue = dilation(mask, disk(margin_px))
32
+
33
+ filtered = []
34
+ for det in detections:
35
+ xi, yi = int(round(det["x"])), int(round(det["y"]))
36
+ if (0 <= yi < tissue.shape[0] and
37
+ 0 <= xi < tissue.shape[1] and
38
+ tissue[yi, xi]):
39
+ filtered.append(det)
40
+ return filtered
41
+
42
+
43
+ def cross_class_nms(
44
+ detections: List[dict],
45
+ distance_threshold: float = 8.0,
46
+ ) -> List[dict]:
47
+ """
48
+ When 6nm and 12nm detections overlap, keep the higher-confidence one.
49
+
50
+ This handles cases where both heads fire on the same particle.
51
+ """
52
+ if len(detections) <= 1:
53
+ return detections
54
+
55
+ # Sort by confidence descending
56
+ dets = sorted(detections, key=lambda d: d["conf"], reverse=True)
57
+ keep = [True] * len(dets)
58
+
59
+ coords = np.array([[d["x"], d["y"]] for d in dets])
60
+
61
+ for i in range(len(dets)):
62
+ if not keep[i]:
63
+ continue
64
+ for j in range(i + 1, len(dets)):
65
+ if not keep[j]:
66
+ continue
67
+ # Only suppress across classes
68
+ if dets[i]["class"] == dets[j]["class"]:
69
+ continue
70
+ dist = np.sqrt(
71
+ (coords[i, 0] - coords[j, 0]) ** 2
72
+ + (coords[i, 1] - coords[j, 1]) ** 2
73
+ )
74
+ if dist < distance_threshold:
75
+ keep[j] = False # Lower confidence suppressed
76
+
77
+ return [d for d, k in zip(dets, keep) if k]
78
+
79
+
80
+ def sweep_confidence_threshold(
81
+ detections: List[dict],
82
+ gt_coords: Dict[str, np.ndarray],
83
+ match_radii: Dict[str, float],
84
+ start: float = 0.05,
85
+ stop: float = 0.95,
86
+ step: float = 0.01,
87
+ ) -> Dict[str, float]:
88
+ """
89
+ Sweep confidence thresholds to find optimal per-class thresholds.
90
+
91
+ Args:
92
+ detections: all detections (before thresholding)
93
+ gt_coords: {'6nm': Nx2, '12nm': Mx2} ground truth
94
+ match_radii: per-class match radii in pixels
95
+ start, stop, step: sweep range
96
+
97
+ Returns:
98
+ Dict with best threshold per class and overall.
99
+ """
100
+ from src.evaluate import match_detections_to_gt, compute_f1
101
+
102
+ best_thresholds = {}
103
+ thresholds = np.arange(start, stop, step)
104
+
105
+ for cls in ["6nm", "12nm"]:
106
+ best_f1 = -1
107
+ best_thr = 0.3
108
+
109
+ for thr in thresholds:
110
+ cls_dets = [d for d in detections if d["class"] == cls and d["conf"] >= thr]
111
+ if not cls_dets and len(gt_coords[cls]) == 0:
112
+ continue
113
+
114
+ pred_coords = np.array([[d["x"], d["y"]] for d in cls_dets]).reshape(-1, 2)
115
+ gt = gt_coords[cls]
116
+
117
+ if len(pred_coords) == 0:
118
+ tp, fp, fn = 0, 0, len(gt)
119
+ elif len(gt) == 0:
120
+ tp, fp, fn = 0, len(pred_coords), 0
121
+ else:
122
+ tp, fp, fn = _simple_match(pred_coords, gt, match_radii[cls])
123
+
124
+ f1, _, _ = compute_f1(tp, fp, fn)
125
+ if f1 > best_f1:
126
+ best_f1 = f1
127
+ best_thr = thr
128
+
129
+ best_thresholds[cls] = best_thr
130
+
131
+ return best_thresholds
132
+
133
+
134
+ def _simple_match(
135
+ pred: np.ndarray, gt: np.ndarray, radius: float
136
+ ) -> tuple:
137
+ """Quick matching for threshold sweep (greedy, not Hungarian)."""
138
+ from scipy.spatial.distance import cdist
139
+
140
+ if len(pred) == 0 or len(gt) == 0:
141
+ return 0, len(pred), len(gt)
142
+
143
+ dists = cdist(pred, gt)
144
+ tp = 0
145
+ matched_gt = set()
146
+
147
+ # Greedy: match closest pairs first
148
+ for i in range(len(pred)):
149
+ min_j = np.argmin(dists[i])
150
+ if dists[i, min_j] <= radius and min_j not in matched_gt:
151
+ tp += 1
152
+ matched_gt.add(min_j)
153
+ dists[:, min_j] = np.inf
154
+
155
+ fp = len(pred) - tp
156
+ fn = len(gt) - tp
157
+ return tp, fp, fn