dboris commited on
Commit
579586d
·
verified ·
1 Parent(s): 02db62d

Upload src/losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/losses.py +165 -0
src/losses.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss functions for fine-grained classification.
3
+
4
+ ArcFace: Angular margin loss — forces angular separation between breed embeddings.
5
+ Poly-1: Drop-in CE replacement with polynomial adjustment.
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class ArcFaceLoss(nn.Module):
15
+ """ArcFace Additive Angular Margin Loss.
16
+
17
+ Projects features onto a hypersphere and enforces angular margin
18
+ between classes. Excellent for fine-grained classification where
19
+ visually similar classes (e.g., Staffordshire vs AmStaff) need
20
+ strong discriminative boundaries.
21
+
22
+ Args:
23
+ embed_dim: Feature embedding dimension
24
+ num_classes: Number of classes
25
+ scale: Feature scale (s). Default: 30.0
26
+ margin: Angular margin (m) in radians. Default: 0.3
27
+ label_smoothing: Smoothing factor. Default: 0.0
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ embed_dim: int,
33
+ num_classes: int,
34
+ scale: float = 30.0,
35
+ margin: float = 0.3,
36
+ label_smoothing: float = 0.0,
37
+ ):
38
+ super().__init__()
39
+ self.scale = scale
40
+ self.margin = margin
41
+ self.label_smoothing = label_smoothing
42
+ self.num_classes = num_classes
43
+
44
+ # Learnable class weight vectors (on unit hypersphere)
45
+ self.weight = nn.Parameter(torch.FloatTensor(num_classes, embed_dim))
46
+ nn.init.xavier_uniform_(self.weight)
47
+
48
+ # Precompute margin terms
49
+ self.cos_m = math.cos(margin)
50
+ self.sin_m = math.sin(margin)
51
+ self.th = math.cos(math.pi - margin)
52
+ self.mm = math.sin(math.pi - margin) * margin
53
+
54
+ def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ Args:
57
+ embeddings: (B, embed_dim) — raw features from backbone (NOT logits)
58
+ labels: (B,) — ground truth class indices
59
+ """
60
+ # Normalize embeddings and weights to unit hypersphere
61
+ embeddings = F.normalize(embeddings, p=2, dim=1)
62
+ weight = F.normalize(self.weight, p=2, dim=1)
63
+
64
+ # Cosine similarity (dot product on unit sphere)
65
+ cosine = F.linear(embeddings, weight) # (B, num_classes)
66
+ sine = torch.sqrt(1.0 - torch.clamp(cosine * cosine, 0, 1))
67
+
68
+ # cos(θ + m) = cos(θ)cos(m) - sin(θ)sin(m)
69
+ phi = cosine * self.cos_m - sine * self.sin_m
70
+
71
+ # Numerical safety: when cos(θ) < cos(π - m), use linearized version
72
+ phi = torch.where(cosine > self.th, phi, cosine - self.mm)
73
+
74
+ # One-hot encode labels
75
+ one_hot = torch.zeros_like(cosine)
76
+ one_hot.scatter_(1, labels.view(-1, 1).long(), 1)
77
+
78
+ # Apply margin only to the target class
79
+ output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
80
+ output *= self.scale
81
+
82
+ # Standard cross-entropy with optional label smoothing
83
+ return F.cross_entropy(output, labels, label_smoothing=self.label_smoothing)
84
+
85
+
86
+ class ArcFaceHead(nn.Module):
87
+ """Combined ArcFace projection head — replaces the standard MLP + CE pipeline.
88
+
89
+ Takes raw backbone features, projects to embedding space, then applies ArcFace.
90
+ During inference, use the projected embeddings for classification via cosine similarity.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ embed_dim: int,
96
+ num_classes: int,
97
+ projection_dim: int = 512,
98
+ scale: float = 30.0,
99
+ margin: float = 0.3,
100
+ dropout: float = 0.3,
101
+ ):
102
+ super().__init__()
103
+ self.projector = nn.Sequential(
104
+ nn.LayerNorm(embed_dim),
105
+ nn.Linear(embed_dim, projection_dim),
106
+ nn.GELU(),
107
+ nn.Dropout(dropout),
108
+ )
109
+ self.arcface = ArcFaceLoss(
110
+ embed_dim=projection_dim,
111
+ num_classes=num_classes,
112
+ scale=scale,
113
+ margin=margin,
114
+ )
115
+ self.num_classes = num_classes
116
+
117
+ def forward(self, features: torch.Tensor, labels: torch.Tensor = None):
118
+ """
119
+ During training (labels provided): returns ArcFace loss
120
+ During inference (no labels): returns cosine similarity logits
121
+ """
122
+ projected = self.projector(features)
123
+
124
+ if labels is not None:
125
+ # Training mode: return loss
126
+ return self.arcface(projected, labels)
127
+ else:
128
+ # Inference mode: return cosine similarity as logits
129
+ projected = F.normalize(projected, p=2, dim=1)
130
+ weight = F.normalize(self.arcface.weight, p=2, dim=1)
131
+ return F.linear(projected, weight) * self.arcface.scale
132
+
133
+
134
+ class Poly1Loss(nn.Module):
135
+ """Poly-1 Cross-Entropy Loss.
136
+
137
+ Near drop-in replacement for CE. Adds a polynomial correction term
138
+ that helps with hard examples. From "PolyLoss" paper (ICLR 2022).
139
+
140
+ Args:
141
+ num_classes: Number of classes
142
+ epsilon: Polynomial coefficient. Default: 1.0
143
+ label_smoothing: Smoothing factor. Default: 0.1
144
+ """
145
+
146
+ def __init__(self, num_classes: int = 120, epsilon: float = 1.0, label_smoothing: float = 0.1):
147
+ super().__init__()
148
+ self.epsilon = epsilon
149
+ self.num_classes = num_classes
150
+ self.label_smoothing = label_smoothing
151
+
152
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
153
+ ce_loss = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
154
+
155
+ # Poly-1 adjustment
156
+ probs = F.softmax(logits, dim=1)
157
+ one_hot = F.one_hot(labels, self.num_classes).float()
158
+
159
+ if self.label_smoothing > 0:
160
+ one_hot = one_hot * (1 - self.label_smoothing) + self.label_smoothing / self.num_classes
161
+
162
+ pt = (probs * one_hot).sum(dim=1) # Probability of true class
163
+ poly1 = ce_loss + self.epsilon * (1 - pt).mean()
164
+
165
+ return poly1