"""
Loss functions for CenterNet immunogold detection.

Implements the CornerNet penalty-reduced focal loss for sparse heatmaps,
a Smooth L1 regression loss for sub-pixel offsets, and their weighted sum.
"""
import torch
import torch.nn.functional as F
def cornernet_focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    alpha: int = 2,
    beta: int = 4,
    conf_weights: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    CornerNet penalty-reduced focal loss for sparse heatmaps.

    The positive:negative pixel ratio is ~1:23,000 per channel, so plain BCE
    is dominated by easy negatives and drifts toward predicting all zeros.
    The focal factors (1 - p)^alpha on positives and p^alpha on negatives
    shrink the contribution of pixels that are already classified well, so
    the loss concentrates on hard pixels. The (1 - gt)^beta factor further
    softens the negative penalty for pixels lying inside a GT Gaussian.

    Args:
        pred: (B, C, H, W) sigmoid-activated predictions in [0, 1]
        gt: (B, C, H, W) Gaussian heatmap targets in [0, 1]; pixels equal
            to exactly 1 are treated as positives, all others as negatives
        alpha: focal exponent for prediction confidence (default 2)
        beta: penalty reduction exponent near GT peaks (default 4)
        conf_weights: optional per-pixel weights (e.g. pseudo-label
            confidence), multiplied into both the positive and the negative
            terms; must broadcast against pred
        eps: lower clamp applied before each log for numerical stability

    Returns:
        Scalar loss: the negated sum of all positive and negative terms,
        divided by the number of positive pixels (at least 1).
    """
    is_peak = gt.eq(1).float()
    not_peak = gt.lt(1).float()

    log_p = torch.log(pred.clamp(min=eps))
    log_not_p = torch.log((1 - pred).clamp(min=eps))

    # Hard positives (low p) keep most of their weight; easy ones fade out.
    positive_term = log_p * torch.pow(1 - pred, alpha) * is_peak

    # Confident false alarms (high p) are penalized, but less so when the
    # pixel sits near a GT peak: (1 - gt)^beta -> 0 there, -> 1 far away.
    negative_term = (
        log_not_p
        * torch.pow(pred, alpha)
        * torch.pow(1 - gt, beta)
        * not_peak
    )

    if conf_weights is not None:
        positive_term = positive_term * conf_weights
        negative_term = negative_term * conf_weights

    # Normalize by the positive count; clamp avoids dividing by zero when a
    # batch contains no particles.
    normalizer = is_peak.sum().clamp(min=1)
    return -(positive_term.sum() + negative_term.sum()) / normalizer
def offset_loss(
    pred_offsets: torch.Tensor,
    gt_offsets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """
    Smooth L1 loss on sub-pixel offsets at annotated particle locations only.

    Args:
        pred_offsets: (B, 2, H, W) predicted offsets
        gt_offsets: (B, 2, H, W) ground truth offsets
        mask: (B, H, W) mask, True / non-zero at particle integer centers.
            Bool is expected; integer or float 0/1 masks are converted to
            bool so they select pixels rather than act as integer indices.

    Returns:
        Scalar loss: mean Smooth L1 (PyTorch default beta=1.0) over both
        offset channels at the masked pixels. When the mask selects nothing,
        a zero that stays attached to ``pred_offsets``' autograd graph.
    """
    # Boolean indexing only works with a bool tensor: a float mask raises,
    # and a long 0/1 mask would be read as integer indices into dim 0.
    mask = mask.bool()
    # Broadcast the per-pixel mask over both offset channels -> (B, 2, H, W).
    mask_expanded = mask.unsqueeze(1).expand_as(pred_offsets)
    if not mask_expanded.any():
        # Graph-connected zero: the offset head still gets (zero) gradients,
        # so no parameter is left without .grad (DistributedDataParallel with
        # find_unused_parameters=False errors otherwise), and dtype/device
        # follow the predictions.
        return pred_offsets.sum() * 0.0
    return F.smooth_l1_loss(
        pred_offsets[mask_expanded],
        gt_offsets[mask_expanded],
        reduction="mean",
    )
def total_loss(
    heatmap_pred: torch.Tensor,
    heatmap_gt: torch.Tensor,
    offset_pred: torch.Tensor,
    offset_gt: torch.Tensor,
    offset_mask: torch.Tensor,
    lambda_offset: float = 1.0,
    focal_alpha: int = 2,
    focal_beta: int = 4,
    conf_weights: torch.Tensor = None,
) -> tuple:
    """
    Combine the heatmap focal loss with the weighted offset regression loss.

    Args:
        heatmap_pred: (B, 2, H, W) sigmoid predictions
        heatmap_gt: (B, 2, H, W) Gaussian GT
        offset_pred: (B, 2, H, W) predicted offsets
        offset_gt: (B, 2, H, W) GT offsets
        offset_mask: (B, H, W) boolean mask of particle centers
        lambda_offset: multiplier on the offset term (default 1.0)
        focal_alpha: focal loss alpha, forwarded to cornernet_focal_loss
        focal_beta: focal loss beta, forwarded to cornernet_focal_loss
        conf_weights: optional per-pixel confidence weights; only the
            heatmap term uses them, the offset term is unweighted

    Returns:
        (total_loss, heatmap_loss_value, offset_loss_value), where the first
        entry is a tensor for backprop and the other two are Python floats
        taken via ``.item()`` for logging.
    """
    hm_term = cornernet_focal_loss(
        heatmap_pred,
        heatmap_gt,
        alpha=focal_alpha,
        beta=focal_beta,
        conf_weights=conf_weights,
    )
    off_term = offset_loss(offset_pred, offset_gt, offset_mask)

    weighted_off = lambda_offset * off_term
    combined = hm_term + weighted_off
    return combined, hm_term.item(), off_term.item()
|