MidasMap / tests /test_heatmap.py
AnikS22's picture
Sync all source code, docs, and configs
1fc7794 verified
"""Unit tests for heatmap GT generation and peak extraction."""
import numpy as np
import pytest
import torch
from src.heatmap import generate_heatmap_gt, extract_peaks
class TestHeatmapGeneration:
def test_single_particle_peak(self):
"""A single particle should produce a Gaussian peak at the correct location."""
coords_6nm = np.array([[100.0, 200.0]])
coords_12nm = np.empty((0, 2))
hm, off, mask, conf = generate_heatmap_gt(
coords_6nm, coords_12nm, 512, 512, stride=2,
)
assert hm.shape == (2, 256, 256)
assert off.shape == (2, 256, 256)
assert mask.shape == (256, 256)
# Peak should be at (50, 100) in stride-2 space
peak_y, peak_x = np.unravel_index(hm[0].argmax(), hm[0].shape)
assert abs(peak_x - 50) <= 1
assert abs(peak_y - 100) <= 1
# Peak value should be 1.0 (confidence=1.0 default)
assert hm[0].max() == pytest.approx(1.0, abs=0.01)
# 12nm channel should be empty
assert hm[1].max() == 0.0
def test_two_classes(self):
"""Both classes should produce peaks in their respective channels."""
coords_6nm = np.array([[100.0, 100.0]])
coords_12nm = np.array([[200.0, 200.0]])
hm, _, _, _ = generate_heatmap_gt(
coords_6nm, coords_12nm, 512, 512, stride=2,
)
assert hm[0].max() > 0.9 # 6nm channel has peak
assert hm[1].max() > 0.9 # 12nm channel has peak
def test_offset_values(self):
"""Offsets should encode sub-pixel correction."""
# Place particle at (101.5, 200.5) → stride-2 center at (50.75, 100.25)
# Integer center: (51, 100) → offset: (-0.25, 0.25)
coords_6nm = np.array([[101.5, 200.5]])
coords_12nm = np.empty((0, 2))
_, off, mask, _ = generate_heatmap_gt(
coords_6nm, coords_12nm, 512, 512, stride=2,
)
# Mask should have exactly one True pixel
assert mask.sum() == 1
def test_empty_annotations(self):
"""Empty annotations should produce zero heatmap."""
hm, off, mask, conf = generate_heatmap_gt(
np.empty((0, 2)), np.empty((0, 2)), 512, 512,
)
assert hm.max() == 0.0
assert mask.sum() == 0
def test_confidence_weighting(self):
"""Confidence < 1 should scale peak value."""
coords = np.array([[100.0, 100.0]])
confidences = np.array([0.5])
hm, _, _, _ = generate_heatmap_gt(
coords, np.empty((0, 2)), 512, 512,
confidence_6nm=confidences,
)
assert hm[0].max() == pytest.approx(0.5, abs=0.05)
def test_overlapping_particles_use_max(self):
"""Overlapping Gaussians should use element-wise max, not sum."""
coords = np.array([[100.0, 100.0], [104.0, 100.0]]) # close together
hm, _, _, _ = generate_heatmap_gt(
coords, np.empty((0, 2)), 512, 512, stride=2,
)
# Max should be 1.0, not >1.0
assert hm[0].max() <= 1.0
class TestPeakExtraction:
def test_single_peak(self):
"""Extract a single peak from synthetic heatmap."""
heatmap = torch.zeros(2, 256, 256)
heatmap[0, 100, 50] = 0.9 # 6nm peak
offset_map = torch.zeros(2, 256, 256)
offset_map[0, 100, 50] = 0.3 # dx
offset_map[1, 100, 50] = 0.1 # dy
dets = extract_peaks(heatmap, offset_map, stride=2, conf_threshold=0.5)
assert len(dets) == 1
assert dets[0]["class"] == "6nm"
assert dets[0]["conf"] == pytest.approx(0.9, abs=0.01)
# x = (50 + 0.3) * 2 = 100.6
assert dets[0]["x"] == pytest.approx(100.6, abs=0.1)
# y = (100 + 0.1) * 2 = 200.2
assert dets[0]["y"] == pytest.approx(200.2, abs=0.1)
def test_nms_suppresses_neighbors(self):
"""NMS should suppress weaker neighboring peaks."""
heatmap = torch.zeros(2, 256, 256)
heatmap[0, 100, 50] = 0.9 # strong
heatmap[0, 101, 50] = 0.7 # weaker neighbor (within NMS kernel)
dets = extract_peaks(
heatmap, torch.zeros(2, 256, 256),
stride=2, conf_threshold=0.5,
nms_kernel_sizes={"6nm": 5, "12nm": 5},
)
# Only the stronger peak should survive
assert len([d for d in dets if d["class"] == "6nm"]) == 1
def test_below_threshold_filtered(self):
"""Peaks below threshold should not be extracted."""
heatmap = torch.zeros(2, 256, 256)
heatmap[0, 100, 50] = 0.2 # below 0.3 threshold
dets = extract_peaks(heatmap, torch.zeros(2, 256, 256), conf_threshold=0.3)
assert len(dets) == 0