# MidasMap / src/heatmap.py
# Uploaded by AnikS22
# Commit 86c24cb (verified): Deploy MidasMap Gradio app, src, requirements, checkpoint
"""
Ground truth heatmap generation and peak extraction for CenterNet.
Generates Gaussian-splat heatmaps at stride-2 resolution with
class-specific sigma values calibrated to bead size.
"""
import numpy as np
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
# Canonical class order: a class's position in this list is its channel index.
CLASS_NAMES = ["6nm", "12nm"]
# Class name -> channel index lookup, derived from CLASS_NAMES so the two
# can never disagree.
CLASS_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}
# Default output stride (input pixels per feature-map cell).
STRIDE = 2
def generate_heatmap_gt(
    coords_6nm: np.ndarray,
    coords_12nm: np.ndarray,
    image_h: int,
    image_w: int,
    sigmas: Optional[Dict[str, float]] = None,
    stride: int = STRIDE,
    confidence_6nm: Optional[np.ndarray] = None,
    confidence_12nm: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build CenterNet training targets from particle coordinates.

    Every particle is rendered as a Gaussian centred on its rounded
    feature-grid cell; overlapping Gaussians of the same class are merged
    with an element-wise maximum. The sub-pixel residual and the particle
    confidence are stored only at that centre cell.

    Args:
        coords_6nm: (N, 2) array of (x, y) in ORIGINAL pixel space
        coords_12nm: (M, 2) array of (x, y) in ORIGINAL pixel space
        image_h: original image height
        image_w: original image width
        sigmas: per-class Gaussian sigma in feature space, keyed by class
            name; defaults to {"6nm": 1.0, "12nm": 1.5}
        stride: output stride (default 2)
        confidence_6nm: optional per-particle confidence weights, indexed
            in the same order as coords_6nm (defaults to all ones)
        confidence_12nm: optional per-particle confidence weights, indexed
            in the same order as coords_12nm (defaults to all ones)

    Returns:
        heatmap: (2, H//stride, W//stride) float32, Gaussian value scaled
            by the particle confidence
        offsets: (2, H//stride, W//stride) float32 sub-pixel offsets;
            channel 0 is x, channel 1 is y, shared by both classes
        offset_mask: (H//stride, W//stride) bool — True at particle centers
        conf_map: (2, H//stride, W//stride) float32 confidence weights,
            1.0 everywhere except at particle centers
    """
    if sigmas is None:
        sigmas = {"6nm": 1.0, "12nm": 1.5}

    h_feat = image_h // stride
    w_feat = image_w // stride

    heatmap = np.zeros((2, h_feat, w_feat), dtype=np.float32)
    offsets = np.zeros((2, h_feat, w_feat), dtype=np.float32)
    offset_mask = np.zeros((h_feat, w_feat), dtype=bool)
    conf_map = np.ones((2, h_feat, w_feat), dtype=np.float32)

    # Flatten both classes into one list of (x, y, class name, confidence).
    # 6nm entries come first, so a 12nm particle landing on the same cell
    # overwrites the shared offset written by a 6nm particle.
    particles = []
    per_class_inputs = (
        ("6nm", coords_6nm, confidence_6nm),
        ("12nm", coords_12nm, confidence_12nm),
    )
    for label, coords, weights in per_class_inputs:
        if len(coords) == 0:
            continue
        if weights is None:
            weights = np.ones(len(coords))
        particles.extend(
            [(px, py, label, weights[k]) for k, (px, py) in enumerate(coords)]
        )

    for px, py, label, weight in particles:
        channel = CLASS_IDX[label]
        sigma = sigmas[label]

        # Continuous position on the feature grid and its rounded cell.
        fx = px / stride
        fy = py / stride
        col = int(round(fx))
        row = int(round(fy))

        # Truncate the Gaussian at roughly 3 sigma, but never below 2 cells.
        reach = max(int(3 * sigma + 1), 2)

        # Clip the splat window to the feature map.
        top = max(0, row - reach)
        bottom = min(h_feat, row + reach + 1)
        left = max(0, col - reach)
        right = min(w_feat, col + reach + 1)
        if top >= bottom or left >= right:
            # Window lies entirely outside the map: nothing to draw.
            continue

        # Squared distance of every window cell to the INTEGER centre,
        # built by broadcasting a column of rows against a row of columns.
        row_ids = np.arange(top, bottom)[:, None]
        col_ids = np.arange(left, right)[None, :]
        dist_sq = (col_ids - col) ** 2 + (row_ids - row) ** 2

        # The Gaussian is centred on the integer cell, not the float
        # position, so that the centre evaluates to exactly 1.0 — the
        # CornerNet focal loss uses pos_mask = (gt == 1.0) and treats
        # everything else as negative. Centering at the float position
        # produces peaks of 0.78-0.93, which the loss sees as negatives,
        # i.e. zero positive training signal.
        splat = np.exp(-dist_sq / (2 * sigma ** 2))

        # Scale by confidence (for pseudo-label weighting).
        # NOTE(review): with weight < 1 this particle's centre value becomes
        # `weight`, not 1.0 — confirm the loss handles that as intended.
        splat = splat * weight

        # Element-wise max keeps the strongest response where particles
        # overlap.
        window = heatmap[channel, top:bottom, left:right]
        window[...] = np.maximum(window, splat)

        # Offset and confidence are written only at the integer centre cell,
        # and only when that cell is inside the feature map.
        if 0 <= row < h_feat and 0 <= col < w_feat:
            offsets[0, row, col] = fx - col
            offsets[1, row, col] = fy - row
            offset_mask[row, col] = True
            conf_map[channel, row, col] = weight

    return heatmap, offsets, offset_mask, conf_map
def extract_peaks(
    heatmap: torch.Tensor,
    offset_map: torch.Tensor,
    stride: int = STRIDE,
    conf_threshold: float = 0.3,
    nms_kernel_sizes: Optional[Dict[str, int]] = None,
) -> List[dict]:
    """
    Extract detections from predicted heatmap via max-pool NMS.

    A cell is kept as a detection when it equals the maximum of its
    kernel-sized neighbourhood and its score is strictly greater than
    ``conf_threshold``. The predicted offset at that cell is added to the
    grid position before scaling back to input pixels.

    Args:
        heatmap: (2, H/stride, W/stride) sigmoid-activated
        offset_map: (2, H/stride, W/stride) raw offset predictions;
            channel 0 is read as dx, channel 1 as dy
        stride: output stride
        conf_threshold: minimum confidence to keep
        nms_kernel_sizes: per-class NMS kernel sizes keyed by class name;
            defaults to {"6nm": 3, "12nm": 5}. Sizes should be odd: the
            pooling pads by ``kernel // 2``, which preserves the map size
            only for odd kernels.

    Returns:
        List of {'x': float, 'y': float, 'class': str, 'conf': float},
        grouped by class in CLASS_NAMES order and row-major within a class.
    """
    if nms_kernel_sizes is None:
        nms_kernel_sizes = {"6nm": 3, "12nm": 5}

    detections = []
    for cls_idx, cls_name in enumerate(CLASS_NAMES):
        hm_cls = heatmap[cls_idx]  # (H, W)
        kernel = nms_kernel_sizes[cls_name]

        # Max-pool NMS. Strip the batch/channel dims by explicit indexing:
        # a bare .squeeze() would also drop H or W when either is 1 and the
        # comparison below would then broadcast to the wrong shape.
        hmax = F.max_pool2d(
            hm_cls[None, None], kernel_size=kernel, stride=1, padding=kernel // 2
        )[0, 0]

        peaks = (hmax == hm_cls) & (hm_cls > conf_threshold)
        ys, xs = torch.where(peaks)

        # Gather every peak's score and offset in one shot and convert to
        # Python numbers once, instead of calling .item() per detection.
        # Indices are moved to offset_map's device in case it differs.
        off_ys = ys.to(offset_map.device)
        off_xs = xs.to(offset_map.device)
        confs = hm_cls[ys, xs].tolist()
        dxs = offset_map[0, off_ys, off_xs].tolist()
        dys = offset_map[1, off_ys, off_xs].tolist()

        for y_i, x_i, conf, dx, dy in zip(
            ys.tolist(), xs.tolist(), confs, dxs, dys
        ):
            # Back to input space with sub-pixel offset
            detections.append({
                "x": (x_i + dx) * stride,
                "y": (y_i + dy) * stride,
                "class": cls_name,
                "conf": conf,
            })

    return detections