File size: 4,470 Bytes
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Post-processing: structural mask filtering, cross-class NMS, threshold sweep.
"""

import numpy as np
from scipy.spatial.distance import cdist
from skimage.morphology import dilation, disk
from typing import Dict, List, Optional


def apply_structural_mask_filter(
    detections: List[dict],
    mask: np.ndarray,
    margin_px: int = 5,
) -> List[dict]:
    """
    Keep only the detections that fall on (slightly grown) tissue.

    Args:
        detections: list of {'x', 'y', 'class', 'conf'}; only 'x' and 'y'
            are read here
        mask: boolean array (H, W) where True = tissue region, or None to
            skip filtering altogether
        margin_px: radius passed to ``disk`` for the dilation footprint

    Returns:
        The input list itself when ``mask`` is None. Otherwise a new list
        holding, in their original order, the detections whose rounded
        (x, y) lies inside the array bounds and on the dilated mask.
    """
    if mask is None:
        return detections

    # Grow the tissue region so particles sitting on its border survive.
    grown = dilation(mask, disk(margin_px))

    def _on_tissue(det: dict) -> bool:
        # x indexes columns, y indexes rows.
        col = int(round(det["x"]))
        row = int(round(det["y"]))
        if not (0 <= row < grown.shape[0]):
            return False
        if not (0 <= col < grown.shape[1]):
            return False
        return bool(grown[row, col])

    return [det for det in detections if _on_tissue(det)]


def cross_class_nms(
    detections: List[dict],
    distance_threshold: float = 8.0,
) -> List[dict]:
    """
    Resolve duplicate hits where two different classes fire on one particle.

    Detections are visited from highest to lowest confidence. Each surviving
    detection suppresses every lower-ranked detection of a *different* class
    that lies strictly closer than ``distance_threshold``. Detections of the
    same class never suppress each other, and an already suppressed
    detection suppresses nothing.

    Args:
        detections: list of dicts with 'x', 'y', 'class' and 'conf' entries
        distance_threshold: Euclidean distance below which two detections
            of different classes are treated as the same particle

    Returns:
        The input list unchanged when it has at most one element; otherwise
        the survivors as a new list sorted by descending confidence.
    """
    if len(detections) <= 1:
        return detections

    # Highest confidence first; sorted() is stable, so ties keep input order.
    ranked = sorted(detections, key=lambda d: d["conf"], reverse=True)
    xy = np.array([[d["x"], d["y"]] for d in ranked])
    total = len(ranked)

    suppressed = set()
    for hi in range(total):
        if hi in suppressed:
            continue
        for lo in range(hi + 1, total):
            if lo in suppressed:
                continue
            # Same-class neighbours are left alone.
            if ranked[hi]["class"] == ranked[lo]["class"]:
                continue
            dx = xy[hi, 0] - xy[lo, 0]
            dy = xy[hi, 1] - xy[lo, 1]
            if np.sqrt(dx ** 2 + dy ** 2) < distance_threshold:
                # `lo` ranks below `hi`, so it is the one that goes.
                suppressed.add(lo)

    return [det for idx, det in enumerate(ranked) if idx not in suppressed]


def sweep_confidence_threshold(
    detections: List[dict],
    gt_coords: Dict[str, np.ndarray],
    match_radii: Dict[str, float],
    start: float = 0.05,
    stop: float = 0.95,
    step: float = 0.01,
) -> Dict[str, float]:
    """
    Sweep confidence thresholds to find the best threshold for each class.

    For each of the classes "6nm" and "12nm", every threshold in
    ``np.arange(start, stop, step)`` is tried: detections of that class with
    ``conf >= threshold`` are matched against the ground truth and the
    threshold giving the highest F1 is kept.

    Args:
        detections: all detections (before thresholding), each a dict with
            'x', 'y', 'class' and 'conf'
        gt_coords: {'6nm': Nx2, '12nm': Mx2} ground truth
        match_radii: per-class match radii in pixels
        start, stop, step: sweep range; ``stop`` itself is excluded because
            the grid comes from ``np.arange``

    Returns:
        Dict mapping "6nm" and "12nm" to their best threshold. A class for
        which no threshold could be scored (no detections and no ground
        truth at every step, or an empty sweep) keeps the fallback 0.3.
        On F1 ties the lowest threshold wins.
    """
    # Project-local helper; the first element of its result is used as F1.
    from src.evaluate import compute_f1

    thresholds = np.arange(start, stop, step)
    best_thresholds = {}

    for cls in ["6nm", "12nm"]:
        # Class membership and coordinates do not depend on the threshold,
        # so gather them once instead of re-filtering on every sweep step.
        cls_dets = [d for d in detections if d["class"] == cls]
        cls_xy = np.array([[d["x"], d["y"]] for d in cls_dets]).reshape(-1, 2)
        cls_conf = np.array([d["conf"] for d in cls_dets], dtype=float)

        best_f1 = -1
        best_thr = 0.3

        for thr in thresholds:
            gt = gt_coords[cls]
            pred_coords = cls_xy[cls_conf >= thr]

            # Nothing predicted and nothing to find: F1 carries no signal.
            if len(pred_coords) == 0 and len(gt) == 0:
                continue

            if len(pred_coords) == 0:
                tp, fp, fn = 0, 0, len(gt)
            elif len(gt) == 0:
                tp, fp, fn = 0, len(pred_coords), 0
            else:
                tp, fp, fn = _simple_match(pred_coords, gt, match_radii[cls])

            f1, _, _ = compute_f1(tp, fp, fn)
            # Strict '>' keeps the lowest threshold among equal F1 scores.
            if f1 > best_f1:
                best_f1 = f1
                best_thr = float(thr)

        best_thresholds[cls] = best_thr

    return best_thresholds


def _simple_match(
    pred: np.ndarray, gt: np.ndarray, radius: float
) -> tuple:
    """
    Quick matching for threshold sweep (greedy, not Hungarian).

    All prediction/ground-truth pairs no farther apart than ``radius`` are
    visited from the closest pair to the farthest, and a pair is accepted
    when neither of its points has been used yet. The outcome therefore
    does not depend on the order in which predictions are supplied.

    Args:
        pred: predicted coordinates, one point per row
        gt: ground-truth coordinates, one point per row
        radius: maximum distance (inclusive) for a valid match

    Returns:
        Tuple ``(tp, fp, fn)``: matched pairs, unmatched predictions and
        unmatched ground-truth points.
    """
    from scipy.spatial.distance import cdist

    if len(pred) == 0 or len(gt) == 0:
        return 0, len(pred), len(gt)

    dists = cdist(pred, gt)

    # Candidate pairs within the radius, ordered by increasing distance.
    # A stable sort keeps ties in row-major (prediction, gt) order.
    cand_pred, cand_gt = np.nonzero(dists <= radius)
    order = np.argsort(dists[cand_pred, cand_gt], kind="stable")

    used_pred = set()
    used_gt = set()
    for k in order:
        i = int(cand_pred[k])
        j = int(cand_gt[k])
        # Each prediction and each ground-truth point is matched at most once.
        if i in used_pred or j in used_gt:
            continue
        used_pred.add(i)
        used_gt.add(j)

    tp = len(used_gt)
    fp = len(pred) - tp
    fn = len(gt) - tp
    return tp, fp, fn