Zorrojurro commited on
Commit
0587d11
·
verified ·
1 Parent(s): 24ee5e9

Upload src/training/losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/training/losses.py +200 -0
src/training/losses.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom loss functions for thermal pattern analysis training.
3
+
4
+ Implements:
5
+ - ContrastiveLoss — pushes same-class pairs together, different-class apart
6
+ - TripletLoss — anchor / positive / negative margin ranking
7
+ - CombinedLoss — weighted sum of both
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class ContrastiveLoss(nn.Module):
16
+ """
17
+ Contrastive loss (Chopra et al., 2005).
18
+
19
+ For a pair of embeddings (e1, e2) with label y ∈ {0, 1}:
20
+ y=0 → same class → loss = ½ · D²
21
+ y=1 → diff class → loss = ½ · max(0, margin − D)²
22
+ where D = ‖e1 − e2‖₂.
23
+ """
24
+
25
+ def __init__(self, margin: float = 1.0):
26
+ super().__init__()
27
+ self.margin = margin
28
+
29
+ def forward(
30
+ self,
31
+ embeddings1: torch.Tensor,
32
+ embeddings2: torch.Tensor,
33
+ labels: torch.Tensor,
34
+ ) -> torch.Tensor:
35
+ """
36
+ Args:
37
+ embeddings1: (B, D)
38
+ embeddings2: (B, D)
39
+ labels: (B,) — 0 if same class, 1 if different
40
+
41
+ Returns:
42
+ Scalar loss.
43
+ """
44
+ distance = F.pairwise_distance(embeddings1, embeddings2)
45
+ loss = (
46
+ (1 - labels) * distance.pow(2)
47
+ + labels * F.relu(self.margin - distance).pow(2)
48
+ )
49
+ return 0.5 * loss.mean()
50
+
51
+
52
+ class TripletLoss(nn.Module):
53
+ """
54
+ Triplet margin loss with optional hard-negative mining.
55
+
56
+ loss = max(0, d(a, p) − d(a, n) + margin)
57
+ """
58
+
59
+ def __init__(self, margin: float = 1.0):
60
+ super().__init__()
61
+ self.loss_fn = nn.TripletMarginLoss(margin=margin, p=2)
62
+
63
+ def forward(
64
+ self,
65
+ anchor: torch.Tensor,
66
+ positive: torch.Tensor,
67
+ negative: torch.Tensor,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Args:
71
+ anchor: (B, D)
72
+ positive: (B, D) — same class as anchor
73
+ negative: (B, D) — different class from anchor
74
+
75
+ Returns:
76
+ Scalar loss.
77
+ """
78
+ return self.loss_fn(anchor, positive, negative)
79
+
80
+
81
+ class CombinedLoss(nn.Module):
82
+ """
83
+ Weighted combination of Contrastive and Triplet losses,
84
+ with a standard cross-entropy classification head.
85
+
86
+ total = α·contrastive + β·triplet + γ·classification
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ contrastive_weight: float = 0.3,
92
+ triplet_weight: float = 0.3,
93
+ classification_weight: float = 0.4,
94
+ triplet_margin: float = 1.0,
95
+ contrastive_margin: float = 1.0,
96
+ ):
97
+ super().__init__()
98
+ self.contrastive_weight = contrastive_weight
99
+ self.triplet_weight = triplet_weight
100
+ self.classification_weight = classification_weight
101
+
102
+ self.contrastive_loss = ContrastiveLoss(margin=contrastive_margin)
103
+ self.triplet_loss = TripletLoss(margin=triplet_margin)
104
+ self.classification_loss = nn.CrossEntropyLoss()
105
+
106
+ @classmethod
107
+ def from_config(cls, config) -> "CombinedLoss":
108
+ """Construct from a Config object."""
109
+ loss_cfg = config.training.loss
110
+ return cls(
111
+ contrastive_weight=loss_cfg.contrastive_weight,
112
+ triplet_weight=loss_cfg.triplet_weight,
113
+ classification_weight=1.0 - loss_cfg.contrastive_weight - loss_cfg.triplet_weight,
114
+ triplet_margin=loss_cfg.triplet_margin,
115
+ )
116
+
117
+ def forward(
118
+ self,
119
+ embeddings: torch.Tensor,
120
+ labels: torch.Tensor,
121
+ logits: torch.Tensor | None = None,
122
+ ) -> dict:
123
+ """
124
+ Compute the combined loss.
125
+
126
+ Uses in-batch pair and triplet mining for efficiency.
127
+
128
+ Args:
129
+ embeddings: (B, D)
130
+ labels: (B,) integer class labels
131
+ logits: (B, num_classes) or None
132
+
133
+ Returns:
134
+ dict with total_loss, contrastive, triplet, classification.
135
+ """
136
+ total = torch.tensor(0.0, device=embeddings.device)
137
+ result = {}
138
+
139
+ # ------- Contrastive: generate in-batch pairs -------
140
+ B = embeddings.size(0)
141
+ if B >= 2:
142
+ idx = torch.randperm(B, device=embeddings.device)
143
+ e1, e2 = embeddings, embeddings[idx]
144
+ pair_labels = (labels != labels[idx]).float()
145
+
146
+ c_loss = self.contrastive_loss(e1, e2, pair_labels)
147
+ total = total + self.contrastive_weight * c_loss
148
+ result["contrastive"] = c_loss.item()
149
+
150
+ # ------- Triplet: mine anchor / pos / neg -------
151
+ anchors, positives, negatives = self._mine_triplets(embeddings, labels)
152
+ if anchors is not None:
153
+ t_loss = self.triplet_loss(anchors, positives, negatives)
154
+ total = total + self.triplet_weight * t_loss
155
+ result["triplet"] = t_loss.item()
156
+
157
+ # ------- Classification -------
158
+ if logits is not None:
159
+ cls_loss = self.classification_loss(logits, labels)
160
+ total = total + self.classification_weight * cls_loss
161
+ result["classification"] = cls_loss.item()
162
+
163
+ result["total_loss"] = total
164
+ return result
165
+
166
+ @staticmethod
167
+ def _mine_triplets(
168
+ embeddings: torch.Tensor, labels: torch.Tensor
169
+ ) -> tuple:
170
+ """Simple in-batch triplet mining."""
171
+ unique_labels = labels.unique()
172
+ if len(unique_labels) < 2:
173
+ return None, None, None
174
+
175
+ anchors, positives, negatives = [], [], []
176
+
177
+ for label in unique_labels:
178
+ mask_pos = labels == label
179
+ mask_neg = labels != label
180
+
181
+ pos_idx = mask_pos.nonzero(as_tuple=True)[0]
182
+ neg_idx = mask_neg.nonzero(as_tuple=True)[0]
183
+
184
+ if len(pos_idx) < 2 or len(neg_idx) < 1:
185
+ continue
186
+
187
+ for i in range(min(len(pos_idx) - 1, 4)): # limit per class
188
+ anchors.append(embeddings[pos_idx[i]])
189
+ positives.append(embeddings[pos_idx[i + 1]])
190
+ neg_i = neg_idx[torch.randint(len(neg_idx), (1,)).item()]
191
+ negatives.append(embeddings[neg_i])
192
+
193
+ if not anchors:
194
+ return None, None, None
195
+
196
+ return (
197
+ torch.stack(anchors),
198
+ torch.stack(positives),
199
+ torch.stack(negatives),
200
+ )