| """ |
| Loss functions for CenterNet immunogold detection. |
| |
| Implements CornerNet penalty-reduced focal loss for sparse heatmaps |
| and smooth L1 offset regression loss. |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
def cornernet_focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    alpha: int = 2,
    beta: int = 4,
    conf_weights: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Penalty-reduced focal loss (CornerNet style) for sparse heatmaps.

    Pixels where ``gt == 1`` are treated as positives; every pixel with
    ``gt < 1`` is a negative whose contribution is scaled down by
    ``(1 - gt) ** beta``, so locations close to a Gaussian peak are
    penalized less than far-away background. The ``(1 - pred) ** alpha``
    and ``pred ** alpha`` factors down-weight pixels the model already
    gets right, which keeps the overwhelming number of easy negatives
    from dominating the loss.

    Args:
        pred: (B, C, H, W) sigmoid-activated predictions in [0, 1]
        gt: (B, C, H, W) Gaussian heatmap targets in [0, 1]
        alpha: focal exponent applied to the prediction confidence
        beta: exponent of the penalty reduction near GT peaks
        conf_weights: optional per-pixel weights (e.g. pseudo-label
            confidence) multiplied into both the positive and the
            negative term; must broadcast against ``pred``
        eps: lower clamp applied before taking logarithms

    Returns:
        Scalar loss, divided by the number of positive pixels
        (or by 1 when there are none).
    """
    is_peak = gt.eq(1).float()
    is_background = gt.lt(1).float()

    # Clamp before the log so that saturated predictions never produce -inf.
    log_hit = pred.clamp(min=eps).log()
    log_miss = (1 - pred).clamp(min=eps).log()

    peak_term = log_hit * (1 - pred).pow(alpha) * is_peak
    background_term = (
        log_miss
        * pred.pow(alpha)
        * (1 - gt).pow(beta)
        * is_background
    )

    if conf_weights is not None:
        peak_term = peak_term * conf_weights
        background_term = background_term * conf_weights

    # At least 1 so an image without any positive still yields a finite loss.
    normalizer = is_peak.sum().clamp(min=1)
    return -(peak_term.sum() + background_term.sum()) / normalizer
|
|
|
|
def offset_loss(
    pred_offsets: torch.Tensor,
    gt_offsets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """
    Smooth L1 loss on sub-pixel offsets at annotated particle locations only.

    Args:
        pred_offsets: (B, 2, H, W) predicted offsets
        gt_offsets: (B, 2, H, W) ground truth offsets
        mask: (B, H, W) mask — True (non-zero) at particle integer centers.
            Non-boolean masks are converted with ``.bool()``.

    Returns:
        Scalar loss: the mean smooth L1 error over the selected elements.
        When the mask selects nothing, the result is a zero that is still
        part of ``pred_offsets``' autograd graph and has its dtype/device.
    """
    # Boolean indexing needs a bool mask; broadcast it over the offset channels.
    selector = mask.bool().unsqueeze(1).expand_as(pred_offsets)
    picked_pred = pred_offsets[selector]

    if picked_pred.numel() == 0:
        # Summing the empty selection gives 0 while keeping the graph
        # connected: backward() then delivers all-zero gradients to
        # pred_offsets instead of leaving them unset, and the value stays
        # finite even if unselected predictions contain NaN/inf.
        return picked_pred.sum()

    return F.smooth_l1_loss(
        picked_pred,
        gt_offsets[selector],
        reduction="mean",
    )
|
|
|
|
def total_loss(
    heatmap_pred: torch.Tensor,
    heatmap_gt: torch.Tensor,
    offset_pred: torch.Tensor,
    offset_gt: torch.Tensor,
    offset_mask: torch.Tensor,
    lambda_offset: float = 1.0,
    focal_alpha: int = 2,
    focal_beta: int = 4,
    conf_weights: torch.Tensor = None,
) -> tuple:
    """
    Weighted sum of the heatmap focal loss and the offset regression loss.

    Args:
        heatmap_pred: (B, 2, H, W) sigmoid predictions
        heatmap_gt: (B, 2, H, W) Gaussian GT
        offset_pred: (B, 2, H, W) predicted offsets
        offset_gt: (B, 2, H, W) GT offsets
        offset_mask: (B, H, W) boolean mask
        lambda_offset: multiplier applied to the offset term
        focal_alpha: ``alpha`` forwarded to the focal loss
        focal_beta: ``beta`` forwarded to the focal loss
        conf_weights: optional per-pixel confidence weights, forwarded
            to the focal loss only

    Returns:
        ``(total, heatmap_value, offset_value)`` — ``total`` is a tensor
        suitable for ``backward()``; the two component values are plain
        Python floats extracted with ``.item()`` for logging.
    """
    heatmap_term = cornernet_focal_loss(
        heatmap_pred,
        heatmap_gt,
        alpha=focal_alpha,
        beta=focal_beta,
        conf_weights=conf_weights,
    )
    offset_term = offset_loss(offset_pred, offset_gt, offset_mask)

    combined = heatmap_term + lambda_offset * offset_term
    return combined, heatmap_term.item(), offset_term.item()
|
|