File size: 6,188 Bytes
"""
Ground-truth heatmap generation and peak extraction for CenterNet.

Generates Gaussian-splat heatmaps at the output-stride resolution (stride 2
by default) with class-specific sigma values calibrated to bead size.
"""
import numpy as np
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
# Class names in heatmap-channel order; CLASS_IDX is derived from this list so
# the name -> channel mapping can never drift out of sync with it.
CLASS_NAMES = ["6nm", "12nm"]
CLASS_IDX = {name: channel for channel, name in enumerate(CLASS_NAMES)}
# Default stride between input pixels and heatmap cells (used as the default
# `stride` argument of the functions below).
STRIDE = 2
def _iter_class_entries(
    coords: np.ndarray,
    confidence: Optional[np.ndarray],
    cls_name: str,
):
    """Yield ``(x, y, conf)`` for every particle of one class.

    ``coords`` that are ``None`` or empty yield nothing. A missing
    ``confidence`` gives every particle weight 1.0; otherwise it must hold
    exactly one weight per coordinate row.

    Raises:
        ValueError: if ``confidence`` and ``coords`` differ in length.
    """
    if coords is None or len(coords) == 0:
        return
    if confidence is None:
        confidence = np.ones(len(coords))
    elif len(confidence) != len(coords):
        # A short array used to fail with IndexError mid-loop and a long one
        # was silently truncated; both mean the weights are misaligned.
        raise ValueError(
            f"confidence_{cls_name} has {len(confidence)} entries but "
            f"coords_{cls_name} has {len(coords)} rows"
        )
    for (x, y), conf in zip(coords, confidence):
        yield x, y, conf


def generate_heatmap_gt(
    coords_6nm: np.ndarray,
    coords_12nm: np.ndarray,
    image_h: int,
    image_w: int,
    sigmas: Optional[Dict[str, float]] = None,
    stride: int = STRIDE,
    confidence_6nm: Optional[np.ndarray] = None,
    confidence_12nm: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate CenterNet ground truth heatmaps and offset maps.

    Args:
        coords_6nm: (N, 2) array of (x, y) in ORIGINAL pixel space
            (``None`` or empty means no particles of this class)
        coords_12nm: (M, 2) array of (x, y) in ORIGINAL pixel space
            (``None`` or empty means no particles of this class)
        image_h: original image height
        image_w: original image width
        sigmas: per-class Gaussian sigma in feature space. Classes missing
            from the dict fall back to the defaults {"6nm": 1.0, "12nm": 1.5}.
        stride: output stride (default 2)
        confidence_6nm: optional per-particle confidence weights, one per row
            of ``coords_6nm``
        confidence_12nm: optional per-particle confidence weights, one per row
            of ``coords_12nm``

    Returns:
        heatmap: (2, H//stride, W//stride) float32, in [0, 1] when all
            confidences are in [0, 1]
        offsets: (2, H//stride, W//stride) float32 sub-pixel offsets
            (channel 0 = x, channel 1 = y), written only at particle centers
        offset_mask: (H//stride, W//stride) bool — True at particle centers
        conf_map: (2, H//stride, W//stride) float32 confidence weights
            (1.0 everywhere except at particle centers)

    Raises:
        ValueError: if a confidence array's length differs from its coords.
    """
    default_sigmas = {"6nm": 1.0, "12nm": 1.5}
    # Merge so callers can override a single class without a KeyError.
    sigmas = {**default_sigmas, **(sigmas or {})}

    h_feat = image_h // stride
    w_feat = image_w // stride

    heatmap = np.zeros((2, h_feat, w_feat), dtype=np.float32)
    offsets = np.zeros((2, h_feat, w_feat), dtype=np.float32)
    offset_mask = np.zeros((h_feat, w_feat), dtype=bool)
    conf_map = np.ones((2, h_feat, w_feat), dtype=np.float32)

    # 6nm particles are drawn before 12nm; since the offset maps are shared
    # across classes, a 12nm center written to the same cell wins.
    per_class = (
        ("6nm", coords_6nm, confidence_6nm),
        ("12nm", coords_12nm, confidence_12nm),
    )
    for cls_name, coords, confidence in per_class:
        cidx = CLASS_IDX[cls_name]
        sigma = sigmas[cls_name]
        # Gaussian radius: truncate at 3 sigma (at least 2 cells)
        r = max(int(3 * sigma + 1), 2)

        for x, y, conf in _iter_class_entries(coords, confidence, cls_name):
            # Feature-space center (float) and its integer grid cell
            cx_f = x / stride
            cy_f = y / stride
            cx_i = int(round(cx_f))
            cy_i = int(round(cy_f))

            # Rounding can push a particle lying in the last 1-2 image pixels
            # onto index w_feat / h_feat (e.g. x=99, W=100, stride 2 -> 49.5
            # -> 50). That particle would then get no 1.0 peak and no offset
            # target. Clamp the centers of in-image particles back onto the
            # map; the sub-pixel offset below absorbs the shift. Particles
            # outside the image are left as-is.
            if 0 <= x < image_w and 0 <= y < image_h:
                cx_i = min(cx_i, w_feat - 1)
                cy_i = min(cy_i, h_feat - 1)

            # Sub-pixel offset in feature cells (normally within [-0.5, 0.5];
            # larger for clamped edge particles).
            off_x = cx_f - cx_i
            off_y = cy_f - cy_i

            # Bounds-clipped grid
            y0 = max(0, cy_i - r)
            y1 = min(h_feat, cy_i + r + 1)
            x0 = max(0, cx_i - r)
            x1 = min(w_feat, cx_i + r + 1)
            if y0 >= y1 or x0 >= x1:
                continue

            yy, xx = np.meshgrid(
                np.arange(y0, y1),
                np.arange(x0, x1),
                indexing="ij",
            )
            # Gaussian centered at the INTEGER center (not the float one).
            # The integer center MUST be exactly 1.0 — the CornerNet focal loss
            # uses pos_mask = (gt == 1.0) and treats everything else as
            # negative. Centering at the float position produces peaks of
            # 0.78-0.93, which the loss sees as negatives -> no positive signal.
            gaussian = np.exp(
                -((xx - cx_i) ** 2 + (yy - cy_i) ** 2) / (2 * sigma ** 2)
            )
            # Scale by confidence (for pseudo-label weighting)
            gaussian = gaussian * conf

            # Element-wise max (handles overlapping particles correctly)
            heatmap[cidx, y0:y1, x0:x1] = np.maximum(
                heatmap[cidx, y0:y1, x0:x1], gaussian
            )

            # Offset and confidence only at the integer center pixel
            if 0 <= cy_i < h_feat and 0 <= cx_i < w_feat:
                offsets[0, cy_i, cx_i] = off_x
                offsets[1, cy_i, cx_i] = off_y
                offset_mask[cy_i, cx_i] = True
                conf_map[cidx, cy_i, cx_i] = conf

    return heatmap, offsets, offset_mask, conf_map
def extract_peaks(
    heatmap: torch.Tensor,
    offset_map: torch.Tensor,
    stride: int = STRIDE,
    conf_threshold: float = 0.3,
    nms_kernel_sizes: Optional[Dict[str, int]] = None,
) -> List[dict]:
    """
    Extract detections from a predicted heatmap via max-pool NMS.

    Args:
        heatmap: (2, H/stride, W/stride) sigmoid-activated, channel order
            matching CLASS_NAMES
        offset_map: (2, H/stride, W/stride) raw offset predictions
            (channel 0 = x, channel 1 = y, in feature cells)
        stride: output stride
        conf_threshold: minimum confidence to keep (strictly greater than)
        nms_kernel_sizes: per-class NMS kernel sizes (positive odd ints).
            Classes missing from the dict fall back to the defaults
            {"6nm": 3, "12nm": 5}.

    Returns:
        List of {'x': float, 'y': float, 'class': str, 'conf': float}, with
        x/y in input-pixel space. Within a class, detections are in
        row-major order.

    Raises:
        ValueError: if an NMS kernel size is not a positive odd integer.
    """
    default_kernels = {"6nm": 3, "12nm": 5}
    # Merge so callers can override a single class without a KeyError.
    nms_kernel_sizes = {**default_kernels, **(nms_kernel_sizes or {})}

    # Decoding never needs gradients.
    heatmap = heatmap.detach()
    offset_map = offset_map.detach()

    detections = []
    for cls_idx, cls_name in enumerate(CLASS_NAMES):
        kernel = nms_kernel_sizes[cls_name]
        # With padding = kernel // 2, an even kernel yields a pooled map one
        # cell larger than the input, so the peak comparison cannot line up.
        if kernel < 1 or kernel % 2 == 0:
            raise ValueError(
                f"NMS kernel for {cls_name} must be a positive odd integer, "
                f"got {kernel}"
            )

        hm = heatmap[cls_idx]  # (H, W)
        # Max-pool NMS. Index [0, 0] rather than squeeze(): squeeze() would
        # also drop a size-1 H or W axis and mis-broadcast against `hm`.
        hmax = F.max_pool2d(
            hm[None, None], kernel_size=kernel, stride=1, padding=kernel // 2
        )[0, 0]
        peaks = (hmax == hm) & (hm > conf_threshold)
        ys, xs = torch.where(peaks)

        # Gather every peak's values in one transfer instead of three
        # device syncs per peak via .item().
        confs = hm[ys, xs].tolist()
        off_ys = ys.to(offset_map.device)
        off_xs = xs.to(offset_map.device)
        dxs, dys = offset_map[:2, off_ys, off_xs].tolist()

        for y_i, x_i, conf, dx, dy in zip(
            ys.tolist(), xs.tolist(), confs, dxs, dys
        ):
            # Back to input space with sub-pixel offset
            detections.append({
                "x": (x_i + dx) * stride,
                "y": (y_i + dy) * stride,
                "class": cls_name,
                "conf": conf,
            })
    return detections
|