File size: 4,701 Bytes
1fc7794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Unit tests for heatmap GT generation and peak extraction."""

import numpy as np
import pytest
import torch

from src.heatmap import generate_heatmap_gt, extract_peaks


class TestHeatmapGeneration:
    """Tests for ``generate_heatmap_gt`` (heatmap, offset, mask and confidence targets)."""

    def test_single_particle_peak(self):
        """A single particle should produce a Gaussian peak at the correct location."""
        coords_6nm = np.array([[100.0, 200.0]])
        coords_12nm = np.empty((0, 2))

        hm, off, mask, _ = generate_heatmap_gt(
            coords_6nm, coords_12nm, 512, 512, stride=2,
        )

        assert hm.shape == (2, 256, 256)
        assert off.shape == (2, 256, 256)
        assert mask.shape == (256, 256)

        # Input coordinates are (x, y): at stride 2 the peak should be at
        # x=50, y=100 in the downsampled map.
        peak_y, peak_x = np.unravel_index(hm[0].argmax(), hm[0].shape)
        assert abs(peak_x - 50) <= 1
        assert abs(peak_y - 100) <= 1

        # Peak value should be 1.0 (confidence=1.0 default)
        assert hm[0].max() == pytest.approx(1.0, abs=0.01)

        # 12nm channel should be empty
        assert hm[1].max() == 0.0

    def test_two_classes(self):
        """Each class should produce a peak in its own channel at its own location."""
        coords_6nm = np.array([[100.0, 100.0]])
        coords_12nm = np.array([[200.0, 200.0]])

        hm, _, _, _ = generate_heatmap_gt(
            coords_6nm, coords_12nm, 512, 512, stride=2,
        )

        assert hm[0].max() > 0.9  # 6nm channel has peak
        assert hm[1].max() > 0.9  # 12nm channel has peak

        # Channel 0 (6nm) must peak near (50, 50) and channel 1 (12nm) near
        # (100, 100) in stride-2 space, i.e. the classes are not swapped.
        expected_peaks = {0: (50, 50), 1: (100, 100)}
        for channel, (exp_x, exp_y) in expected_peaks.items():
            peak_y, peak_x = np.unravel_index(
                hm[channel].argmax(), hm[channel].shape,
            )
            assert abs(peak_x - exp_x) <= 1
            assert abs(peak_y - exp_y) <= 1

    def test_offset_values(self):
        """Offsets should encode sub-pixel correction."""
        # Place particle at (101.5, 200.5) → stride-2 center at (50.75, 100.25).
        # Whichever integer cell the generator assigns, decoding that cell as
        # (cell + offset) * stride (offset channel 0 = dx, channel 1 = dy, the
        # same rule extract_peaks applies) must recover the input coordinates.
        coords_6nm = np.array([[101.5, 200.5]])
        coords_12nm = np.empty((0, 2))
        stride = 2

        _, off, mask, _ = generate_heatmap_gt(
            coords_6nm, coords_12nm, 512, 512, stride=stride,
        )

        # Mask should have exactly one True pixel
        assert mask.sum() == 1

        cell_y, cell_x = (int(v) for v in np.argwhere(np.asarray(mask))[0])
        dx = float(off[0, cell_y, cell_x])
        dy = float(off[1, cell_y, cell_x])

        assert (cell_x + dx) * stride == pytest.approx(101.5, abs=0.01)
        assert (cell_y + dy) * stride == pytest.approx(200.5, abs=0.01)
        # A sub-pixel correction never spans a whole cell.
        assert abs(dx) < 1
        assert abs(dy) < 1

    def test_empty_annotations(self):
        """Empty annotations should produce zero heatmap."""
        hm, _, mask, _ = generate_heatmap_gt(
            np.empty((0, 2)), np.empty((0, 2)), 512, 512,
        )
        assert hm.max() == 0.0
        assert mask.sum() == 0

    def test_confidence_weighting(self):
        """Confidence < 1 should scale peak value."""
        coords = np.array([[100.0, 100.0]])
        confidences = np.array([0.5])

        hm, _, _, _ = generate_heatmap_gt(
            coords, np.empty((0, 2)), 512, 512,
            confidence_6nm=confidences,
        )

        assert hm[0].max() == pytest.approx(0.5, abs=0.05)

    def test_overlapping_particles_use_max(self):
        """Overlapping Gaussians should use element-wise max, not sum."""
        coords = np.array([[100.0, 100.0], [104.0, 100.0]])  # close together
        empty = np.empty((0, 2))

        hm_pair, _, _, _ = generate_heatmap_gt(coords, empty, 512, 512, stride=2)
        hm_a, _, _, _ = generate_heatmap_gt(coords[:1], empty, 512, 512, stride=2)
        hm_b, _, _, _ = generate_heatmap_gt(coords[1:], empty, 512, 512, stride=2)

        # Max should be 1.0, not >1.0
        assert hm_pair[0].max() <= 1.0
        # A clipped sum would also satisfy the bound above; comparing against
        # the element-wise max of the two single-particle heatmaps pins down
        # the combination rule itself.
        np.testing.assert_allclose(
            np.asarray(hm_pair[0]),
            np.maximum(np.asarray(hm_a[0]), np.asarray(hm_b[0])),
            atol=1e-6,
        )


class TestPeakExtraction:
    """Tests for ``extract_peaks`` (thresholding, NMS and offset decoding)."""

    def test_single_peak(self):
        """Extract a single peak from synthetic heatmap."""
        heatmap = torch.zeros(2, 256, 256)
        heatmap[0, 100, 50] = 0.9  # 6nm peak

        offset_map = torch.zeros(2, 256, 256)
        offset_map[0, 100, 50] = 0.3  # dx
        offset_map[1, 100, 50] = 0.1  # dy

        dets = extract_peaks(heatmap, offset_map, stride=2, conf_threshold=0.5)

        assert len(dets) == 1
        assert dets[0]["class"] == "6nm"
        assert dets[0]["conf"] == pytest.approx(0.9, abs=0.01)
        # x = (50 + 0.3) * 2 = 100.6
        assert dets[0]["x"] == pytest.approx(100.6, abs=0.1)
        # y = (100 + 0.1) * 2 = 200.2
        assert dets[0]["y"] == pytest.approx(200.2, abs=0.1)

    def test_nms_suppresses_neighbors(self):
        """NMS should suppress weaker neighboring peaks and keep the strongest."""
        heatmap = torch.zeros(2, 256, 256)
        heatmap[0, 100, 50] = 0.9  # strong
        heatmap[0, 101, 50] = 0.7  # weaker neighbor (within NMS kernel)

        dets = extract_peaks(
            heatmap, torch.zeros(2, 256, 256),
            stride=2, conf_threshold=0.5,
            nms_kernel_sizes={"6nm": 5, "12nm": 5},
        )

        # Only the stronger peak should survive — check it is actually the
        # 0.9 peak at (row 100, col 50), not the 0.7 neighbor at row 101.
        six_nm = [d for d in dets if d["class"] == "6nm"]
        assert len(six_nm) == 1
        assert six_nm[0]["conf"] == pytest.approx(0.9, abs=0.01)
        # Zero offsets: x = 50 * 2 = 100, y = 100 * 2 = 200
        assert six_nm[0]["x"] == pytest.approx(100.0, abs=0.1)
        assert six_nm[0]["y"] == pytest.approx(200.0, abs=0.1)

    def test_below_threshold_filtered(self):
        """Peaks below threshold should not be extracted."""
        heatmap = torch.zeros(2, 256, 256)
        heatmap[0, 100, 50] = 0.2  # below 0.3 threshold

        dets = extract_peaks(heatmap, torch.zeros(2, 256, 256), conf_threshold=0.3)
        assert len(dets) == 0