| """Unit tests for heatmap GT generation and peak extraction.""" |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
| from src.heatmap import generate_heatmap_gt, extract_peaks |
|
|
|
|
| class TestHeatmapGeneration: |
| def test_single_particle_peak(self): |
| """A single particle should produce a Gaussian peak at the correct location.""" |
| coords_6nm = np.array([[100.0, 200.0]]) |
| coords_12nm = np.empty((0, 2)) |
|
|
| hm, off, mask, conf = generate_heatmap_gt( |
| coords_6nm, coords_12nm, 512, 512, stride=2, |
| ) |
|
|
| assert hm.shape == (2, 256, 256) |
| assert off.shape == (2, 256, 256) |
| assert mask.shape == (256, 256) |
|
|
| |
| peak_y, peak_x = np.unravel_index(hm[0].argmax(), hm[0].shape) |
| assert abs(peak_x - 50) <= 1 |
| assert abs(peak_y - 100) <= 1 |
|
|
| |
| assert hm[0].max() == pytest.approx(1.0, abs=0.01) |
|
|
| |
| assert hm[1].max() == 0.0 |
|
|
| def test_two_classes(self): |
| """Both classes should produce peaks in their respective channels.""" |
| coords_6nm = np.array([[100.0, 100.0]]) |
| coords_12nm = np.array([[200.0, 200.0]]) |
|
|
| hm, _, _, _ = generate_heatmap_gt( |
| coords_6nm, coords_12nm, 512, 512, stride=2, |
| ) |
|
|
| assert hm[0].max() > 0.9 |
| assert hm[1].max() > 0.9 |
|
|
| def test_offset_values(self): |
| """Offsets should encode sub-pixel correction.""" |
| |
| |
| coords_6nm = np.array([[101.5, 200.5]]) |
| coords_12nm = np.empty((0, 2)) |
|
|
| _, off, mask, _ = generate_heatmap_gt( |
| coords_6nm, coords_12nm, 512, 512, stride=2, |
| ) |
|
|
| |
| assert mask.sum() == 1 |
|
|
| def test_empty_annotations(self): |
| """Empty annotations should produce zero heatmap.""" |
| hm, off, mask, conf = generate_heatmap_gt( |
| np.empty((0, 2)), np.empty((0, 2)), 512, 512, |
| ) |
| assert hm.max() == 0.0 |
| assert mask.sum() == 0 |
|
|
| def test_confidence_weighting(self): |
| """Confidence < 1 should scale peak value.""" |
| coords = np.array([[100.0, 100.0]]) |
| confidences = np.array([0.5]) |
|
|
| hm, _, _, _ = generate_heatmap_gt( |
| coords, np.empty((0, 2)), 512, 512, |
| confidence_6nm=confidences, |
| ) |
|
|
| assert hm[0].max() == pytest.approx(0.5, abs=0.05) |
|
|
| def test_overlapping_particles_use_max(self): |
| """Overlapping Gaussians should use element-wise max, not sum.""" |
| coords = np.array([[100.0, 100.0], [104.0, 100.0]]) |
| hm, _, _, _ = generate_heatmap_gt( |
| coords, np.empty((0, 2)), 512, 512, stride=2, |
| ) |
| |
| assert hm[0].max() <= 1.0 |
|
|
|
|
| class TestPeakExtraction: |
| def test_single_peak(self): |
| """Extract a single peak from synthetic heatmap.""" |
| heatmap = torch.zeros(2, 256, 256) |
| heatmap[0, 100, 50] = 0.9 |
|
|
| offset_map = torch.zeros(2, 256, 256) |
| offset_map[0, 100, 50] = 0.3 |
| offset_map[1, 100, 50] = 0.1 |
|
|
| dets = extract_peaks(heatmap, offset_map, stride=2, conf_threshold=0.5) |
|
|
| assert len(dets) == 1 |
| assert dets[0]["class"] == "6nm" |
| assert dets[0]["conf"] == pytest.approx(0.9, abs=0.01) |
| |
| assert dets[0]["x"] == pytest.approx(100.6, abs=0.1) |
| |
| assert dets[0]["y"] == pytest.approx(200.2, abs=0.1) |
|
|
| def test_nms_suppresses_neighbors(self): |
| """NMS should suppress weaker neighboring peaks.""" |
| heatmap = torch.zeros(2, 256, 256) |
| heatmap[0, 100, 50] = 0.9 |
| heatmap[0, 101, 50] = 0.7 |
|
|
| dets = extract_peaks( |
| heatmap, torch.zeros(2, 256, 256), |
| stride=2, conf_threshold=0.5, |
| nms_kernel_sizes={"6nm": 5, "12nm": 5}, |
| ) |
|
|
| |
| assert len([d for d in dets if d["class"] == "6nm"]) == 1 |
|
|
| def test_below_threshold_filtered(self): |
| """Peaks below threshold should not be extracted.""" |
| heatmap = torch.zeros(2, 256, 256) |
| heatmap[0, 100, 50] = 0.2 |
|
|
| dets = extract_peaks(heatmap, torch.zeros(2, 256, 256), conf_threshold=0.3) |
| assert len(dets) == 0 |
|
|