| """ |
| Ground truth heatmap generation and peak extraction for CenterNet. |
| |
| Generates Gaussian-splat heatmaps at stride-2 resolution with |
| class-specific sigma values calibrated to bead size. |
| """ |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from typing import Dict, List, Tuple, Optional |
|
|
| |
# Ordered class names; a class's position in this list is its heatmap channel.
CLASS_NAMES = ["6nm", "12nm"]

# Reverse lookup (class name -> channel index), derived from CLASS_NAMES so the
# two tables cannot drift apart when a class is added or reordered.
CLASS_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}

# Default ratio between input-image pixels and heatmap cells.
STRIDE = 2
|
|
|
|
def _nearest_cell(pos: float) -> int:
    """Return the integer cell nearest to *pos*, sending exact .5 ties downward.

    The built-in ``round`` resolves ties to the nearest even integer, so a
    half-way position would snap up or down depending on the parity of the
    neighbouring cell index. Breaking ties toward the lower cell is
    deterministic and keeps a tie on the last cell inside the map.
    """
    return int(np.ceil(pos - 0.5))


def generate_heatmap_gt(
    coords_6nm: np.ndarray,
    coords_12nm: np.ndarray,
    image_h: int,
    image_w: int,
    sigmas: Optional[Dict[str, float]] = None,
    stride: int = STRIDE,
    confidence_6nm: Optional[np.ndarray] = None,
    confidence_12nm: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate CenterNet ground truth heatmaps and offset maps.

    Each particle is snapped to its nearest feature-map cell, a Gaussian
    (scaled by the particle's confidence) is splatted around that cell, and
    the sub-cell remainder is stored as the offset target for that cell.

    Args:
        coords_6nm: (N, 2) array of (x, y) in ORIGINAL pixel space
        coords_12nm: (M, 2) array of (x, y) in ORIGINAL pixel space
        image_h: original image height
        image_w: original image width
        sigmas: per-class Gaussian sigma in feature space; classes missing
            from the dict fall back to the defaults (6nm: 1.0, 12nm: 1.5)
        stride: output stride (default 2)
        confidence_6nm: optional per-particle confidence weights, indexed
            in the same order as ``coords_6nm``; defaults to 1.0 each
        confidence_12nm: optional per-particle confidence weights, indexed
            in the same order as ``coords_12nm``; defaults to 1.0 each

    Returns:
        heatmap: (2, H//stride, W//stride) float32; overlapping Gaussians
            of the same class are merged with an element-wise maximum
        offsets: (2, H//stride, W//stride) float32 sub-pixel offsets
            (channel 0 = x, channel 1 = y), shared by both classes, so that
            ``(cell + offset) * stride`` reproduces the input coordinate
        offset_mask: (H//stride, W//stride) bool — True at particle centers
        conf_map: (2, H//stride, W//stride) float32 confidence weights,
            1.0 everywhere except at particle centers
    """
    default_sigmas = {"6nm": 1.0, "12nm": 1.5}
    # Caller-supplied values win; classes the caller left out keep the default
    # instead of raising KeyError later on.
    sigmas = {**default_sigmas, **(sigmas or {})}

    h_feat = image_h // stride
    w_feat = image_w // stride

    heatmap = np.zeros((2, h_feat, w_feat), dtype=np.float32)
    offsets = np.zeros((2, h_feat, w_feat), dtype=np.float32)
    offset_mask = np.zeros((h_feat, w_feat), dtype=bool)
    conf_map = np.ones((2, h_feat, w_feat), dtype=np.float32)

    # Nothing can be drawn on an empty feature map (and the clamping below
    # needs at least one valid cell per axis).
    if h_feat == 0 or w_feat == 0:
        return heatmap, offsets, offset_mask, conf_map

    # Processed in this fixed order: where two particles share a center cell,
    # the later one overwrites the shared offset entry.
    per_class = (
        ("6nm", coords_6nm, confidence_6nm),
        ("12nm", coords_12nm, confidence_12nm),
    )

    for cls, coords, confidences in per_class:
        if len(coords) == 0:
            continue

        # Per-class invariants, hoisted out of the particle loop.
        cidx = CLASS_IDX[cls]
        sigma = sigmas[cls]
        radius = max(int(3 * sigma + 1), 2)  # ~3 sigma, at least 2 cells
        denom = 2 * sigma ** 2

        for i, (x, y) in enumerate(coords):
            conf = 1.0 if confidences is None else confidences[i]

            # Continuous position in feature-map coordinates.
            cx_f = x / stride
            cy_f = y / stride

            cx_i = _nearest_cell(cx_f)
            cy_i = _nearest_cell(cy_f)

            # A particle that lies inside the image can still round to the
            # index one past the last cell (e.g. in the final half cell).
            # Clamp it back so it keeps its heatmap peak and offset target;
            # the offset absorbs the difference. Particles outside the image
            # are left alone and contribute only their Gaussian tail, if any.
            if 0 <= x < image_w and 0 <= y < image_h:
                cx_i = min(max(cx_i, 0), w_feat - 1)
                cy_i = min(max(cy_i, 0), h_feat - 1)

            off_x = cx_f - cx_i
            off_y = cy_f - cy_i

            # Splat window, clipped to the feature map.
            y0 = max(0, cy_i - radius)
            y1 = min(h_feat, cy_i + radius + 1)
            x0 = max(0, cx_i - radius)
            x1 = min(w_feat, cx_i + radius + 1)
            if y0 >= y1 or x0 >= x1:
                continue

            # The Gaussian is centered on the integer cell, so the center
            # cell itself receives exactly `conf`.
            dy_sq = (np.arange(y0, y1) - cy_i) ** 2
            dx_sq = (np.arange(x0, x1) - cx_i) ** 2
            gaussian = np.exp(-(dy_sq[:, None] + dx_sq[None, :]) / denom) * conf

            heatmap[cidx, y0:y1, x0:x1] = np.maximum(
                heatmap[cidx, y0:y1, x0:x1], gaussian
            )

            if 0 <= cy_i < h_feat and 0 <= cx_i < w_feat:
                offsets[0, cy_i, cx_i] = off_x
                offsets[1, cy_i, cx_i] = off_y
                offset_mask[cy_i, cx_i] = True
                conf_map[cidx, cy_i, cx_i] = conf

    return heatmap, offsets, offset_mask, conf_map
|
|
|
|
def extract_peaks(
    heatmap: torch.Tensor,
    offset_map: torch.Tensor,
    stride: int = STRIDE,
    conf_threshold: float = 0.3,
    nms_kernel_sizes: Optional[Dict[str, int]] = None,
) -> List[dict]:
    """
    Extract detections from predicted heatmap via max-pool NMS.

    A cell is kept when it equals the maximum of its ``kernel x kernel``
    neighbourhood and its score is strictly above ``conf_threshold``. The
    cell index plus the predicted offset is then scaled by ``stride``.

    Args:
        heatmap: (2, H/stride, W/stride) scores, expected sigmoid-activated
        offset_map: (2, H/stride, W/stride) raw offset predictions
            (channel 0 = x, channel 1 = y)
        stride: output stride
        conf_threshold: minimum confidence to keep
        nms_kernel_sizes: per-class NMS kernel sizes (odd, >= 1); classes
            missing from the dict fall back to the defaults (6nm: 3, 12nm: 5)

    Returns:
        List of {'x': float, 'y': float, 'class': str, 'conf': float},
        grouped by class in CLASS_NAMES order and row-major within a class.

    Raises:
        ValueError: if an NMS kernel size is even or smaller than 1.
    """
    default_kernels = {"6nm": 3, "12nm": 5}
    # Caller-supplied values win; omitted classes keep the default.
    kernel_sizes = {**default_kernels, **(nms_kernel_sizes or {})}

    detections = []

    for cls_idx, cls_name in enumerate(CLASS_NAMES):
        scores = heatmap[cls_idx]
        kernel = kernel_sizes[cls_name]
        # With padding = kernel // 2 only an odd kernel yields a pooled map
        # of the same size as the input; an even one is one cell larger.
        if kernel < 1 or kernel % 2 == 0:
            raise ValueError(
                f"NMS kernel size for class {cls_name!r} must be a positive "
                f"odd integer, got {kernel}"
            )

        # Index [0, 0] rather than squeeze(): squeeze() would also drop a
        # height or width of 1 and make the comparison below broadcast to
        # the wrong shape.
        pooled = F.max_pool2d(
            scores[None, None], kernel_size=kernel, stride=1, padding=kernel // 2
        )[0, 0]
        peaks = (pooled == scores) & (scores > conf_threshold)

        ys, xs = torch.where(peaks)
        if ys.numel() == 0:
            continue

        # Gather everything in bulk instead of one .item() call per value.
        ys_off = ys.to(offset_map.device)
        xs_off = xs.to(offset_map.device)
        confs = scores[ys, xs].tolist()
        dxs = offset_map[0, ys_off, xs_off].tolist()
        dys = offset_map[1, ys_off, xs_off].tolist()

        # Decode in Python floats: cell index + predicted offset, scaled
        # back to original-image pixels.
        for x_i, y_i, conf, dx, dy in zip(xs.tolist(), ys.tolist(), confs, dxs, dys):
            detections.append({
                "x": (x_i + dx) * stride,
                "y": (y_i + dy) * stride,
                "class": cls_name,
                "conf": conf,
            })

    return detections
|
|