AnikS22 commited on
Commit
cf82a19
·
verified ·
1 Parent(s): 88a76dc

Upload src/loss.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/loss.py +137 -0
src/loss.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss functions for CenterNet immunogold detection.
3
+
4
+ Implements CornerNet penalty-reduced focal loss for sparse heatmaps
5
+ and smooth L1 offset regression loss.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
def cornernet_focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    alpha: int = 2,
    beta: int = 4,
    conf_weights: "torch.Tensor | None" = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    CornerNet penalty-reduced focal loss for sparse heatmaps.

    The positive:negative pixel ratio is ~1:23,000 per channel.
    Standard BCE would learn to predict all zeros. This loss
    penalizes confident wrong predictions and rewards uncertain
    correct ones via the (1-p)^alpha and p^alpha terms.

    Args:
        pred: (B, C, H, W) sigmoid-activated predictions in [0, 1]
        gt: (B, C, H, W) Gaussian heatmap targets in [0, 1]
        alpha: focal exponent for prediction confidence (default 2)
        beta: penalty reduction exponent near GT peaks (default 4)
        conf_weights: optional (B, C, H, W) per-pixel confidence weights
            for pseudo-label weighting (None disables weighting)
        eps: numerical stability floor inside the logs

    Returns:
        Scalar loss, normalized by number of positive locations.
    """
    # Positives are exact GT peaks (gt == 1); everything else is negative,
    # including the Gaussian tails around peaks.
    pos_mask = (gt == 1).float()
    neg_mask = (gt < 1).float()

    # Penalty reduction: pixels near particle centers get lower negative penalty
    # (1 - gt)^beta → 0 near peaks, → 1 far from peaks
    neg_weights = torch.pow(1 - gt, beta)

    # Positive loss: encourage high confidence at GT peaks.
    # clamp(min=eps) guards log(0) when pred underflows to 0.
    pos_loss = torch.log(pred.clamp(min=eps)) * torch.pow(1 - pred, alpha) * pos_mask

    # Negative loss: penalize high confidence away from GT peaks.
    # clamp guards log(0) when pred saturates at 1.
    neg_loss = (
        torch.log((1 - pred).clamp(min=eps))
        * torch.pow(pred, alpha)
        * neg_weights
        * neg_mask
    )

    # Apply confidence weighting if provided (for pseudo-label support);
    # both terms are scaled so low-confidence pseudo-labels contribute less.
    if conf_weights is not None:
        pos_loss = pos_loss * conf_weights
        neg_loss = neg_loss * conf_weights

    # Normalize by positive count; clamp avoids division by zero on
    # images with no annotated particles.
    num_pos = pos_mask.sum().clamp(min=1)
    loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos

    return loss
68
+
69
+
70
def offset_loss(
    pred_offsets: torch.Tensor,
    gt_offsets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """
    Smooth L1 regression loss on sub-pixel offsets, evaluated only at
    annotated particle integer-center locations.

    Args:
        pred_offsets: (B, 2, H, W) predicted offsets
        gt_offsets: (B, 2, H, W) ground truth offsets
        mask: (B, H, W) boolean — True at particle integer centers

    Returns:
        Scalar loss (0 when the batch has no annotated centers).
    """
    # Broadcast the (B, H, W) mask across both offset channels.
    channel_mask = mask.unsqueeze(1).expand_as(pred_offsets)

    # No annotated centers anywhere in the batch: return a differentiable
    # zero so the caller's combined loss still supports backward().
    if channel_mask.sum() == 0:
        return torch.tensor(0.0, device=pred_offsets.device, requires_grad=True)

    selected_pred = pred_offsets[channel_mask]
    selected_gt = gt_offsets[channel_mask]
    return F.smooth_l1_loss(selected_pred, selected_gt, reduction="mean")
98
+
99
+
100
def total_loss(
    heatmap_pred: torch.Tensor,
    heatmap_gt: torch.Tensor,
    offset_pred: torch.Tensor,
    offset_gt: torch.Tensor,
    offset_mask: torch.Tensor,
    lambda_offset: float = 1.0,
    focal_alpha: int = 2,
    focal_beta: int = 4,
    conf_weights: torch.Tensor = None,
) -> tuple:
    """
    Combined heatmap focal loss + offset regression loss.

    Args:
        heatmap_pred: (B, 2, H, W) sigmoid predictions
        heatmap_gt: (B, 2, H, W) Gaussian GT
        offset_pred: (B, 2, H, W) predicted offsets
        offset_gt: (B, 2, H, W) GT offsets
        offset_mask: (B, H, W) boolean mask
        lambda_offset: weight for offset loss (default 1.0)
        focal_alpha: focal loss alpha
        focal_beta: focal loss beta
        conf_weights: optional per-pixel confidence weights

    Returns:
        (total_loss, heatmap_loss_value, offset_loss_value) — the first is
        a differentiable tensor, the other two are plain floats for logging.
    """
    heatmap_term = cornernet_focal_loss(
        heatmap_pred,
        heatmap_gt,
        alpha=focal_alpha,
        beta=focal_beta,
        conf_weights=conf_weights,
    )
    offset_term = offset_loss(offset_pred, offset_gt, offset_mask)

    combined = heatmap_term + lambda_offset * offset_term
    return combined, heatmap_term.item(), offset_term.item()