File size: 4,211 Bytes
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Loss functions for CenterNet immunogold detection.

Implements CornerNet penalty-reduced focal loss for sparse heatmaps
and smooth L1 offset regression loss.
"""

import torch
import torch.nn.functional as F


def cornernet_focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    alpha: int = 2,
    beta: int = 4,
    conf_weights: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Penalty-reduced focal loss (CornerNet variant) for sparse heatmaps.

    With ~1:23,000 positive:negative pixels per channel, plain BCE collapses
    to the all-zero prediction. The (1-p)^alpha / p^alpha focal terms focus
    the loss on confident mistakes, and the (1-gt)^beta factor softens the
    negative penalty in the Gaussian neighborhood of each GT peak.

    Args:
        pred: (B, C, H, W) sigmoid-activated predictions in [0, 1]
        gt: (B, C, H, W) Gaussian heatmap targets in [0, 1]
        alpha: focal exponent on prediction confidence (default 2)
        beta: penalty-reduction exponent near GT peaks (default 4)
        conf_weights: optional (B, C, H, W) per-pixel confidence weights
                      for pseudo-label weighting
        eps: clamp floor for the log arguments (numerical stability)

    Returns:
        Scalar loss, normalized by the number of positive locations.
    """
    is_peak = gt.eq(1).float()
    off_peak = gt.lt(1).float()

    # (1 - gt)^beta: -> 0 at peaks, -> 1 far away, so negatives adjacent
    # to particle centers are penalized gently.
    reduced_penalty = (1 - gt).pow(beta)

    log_p = pred.clamp(min=eps).log()
    log_not_p = (1 - pred).clamp(min=eps).log()

    # Reward confident predictions at GT peaks.
    pos_term = is_peak * (1 - pred).pow(alpha) * log_p

    # Punish confident predictions away from GT peaks.
    neg_term = off_peak * reduced_penalty * pred.pow(alpha) * log_not_p

    # Optional pseudo-label confidence weighting (scales both sides).
    if conf_weights is not None:
        pos_term = conf_weights * pos_term
        neg_term = conf_weights * neg_term

    n_pos = is_peak.sum().clamp(min=1)
    return -(pos_term.sum() + neg_term.sum()) / n_pos


def offset_loss(
    pred_offsets: torch.Tensor,
    gt_offsets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """
    Smooth L1 loss on sub-pixel offsets at annotated particle locations only.

    Args:
        pred_offsets: (B, 2, H, W) predicted offsets
        gt_offsets: (B, 2, H, W) ground truth offsets
        mask: (B, H, W) boolean — True at particle integer centers

    Returns:
        Scalar loss (0 when the mask selects no locations).
    """
    # Expand mask across the 2 offset channels.
    mask_expanded = mask.unsqueeze(1).expand_as(pred_offsets)

    if not mask_expanded.any():
        # No annotated centers in this batch. Return a zero that is still
        # attached to pred_offsets' autograd graph: a detached
        # torch.tensor(0.0, requires_grad=True) would leave the offset head
        # without gradients and trips DDP's unused-parameter detection.
        return pred_offsets.sum() * 0.0

    return F.smooth_l1_loss(
        pred_offsets[mask_expanded],
        gt_offsets[mask_expanded],
        reduction="mean",
    )


def total_loss(
    heatmap_pred: torch.Tensor,
    heatmap_gt: torch.Tensor,
    offset_pred: torch.Tensor,
    offset_gt: torch.Tensor,
    offset_mask: torch.Tensor,
    lambda_offset: float = 1.0,
    focal_alpha: int = 2,
    focal_beta: int = 4,
    conf_weights: torch.Tensor = None,
) -> tuple:
    """
    Combined CenterNet objective: heatmap focal loss + offset regression.

    Args:
        heatmap_pred: (B, 2, H, W) sigmoid predictions
        heatmap_gt: (B, 2, H, W) Gaussian GT
        offset_pred: (B, 2, H, W) predicted offsets
        offset_gt: (B, 2, H, W) GT offsets
        offset_mask: (B, H, W) boolean mask
        lambda_offset: weight for the offset term (default 1.0)
        focal_alpha: focal loss alpha
        focal_beta: focal loss beta
        conf_weights: optional per-pixel confidence weights

    Returns:
        (total_loss, heatmap_loss_value, offset_loss_value) — the first
        element is a differentiable tensor, the other two plain floats
        for logging.
    """
    heatmap_term = cornernet_focal_loss(
        heatmap_pred,
        heatmap_gt,
        alpha=focal_alpha,
        beta=focal_beta,
        conf_weights=conf_weights,
    )
    offset_term = offset_loss(offset_pred, offset_gt, offset_mask)

    combined = heatmap_term + lambda_offset * offset_term
    return combined, heatmap_term.item(), offset_term.item()