"""
Loss functions for fine-grained classification.

ArcFace: Additive angular margin loss — enforces angular separation between breed embeddings.
Poly-1: Drop-in cross-entropy replacement with a polynomial correction term.
"""
|
|
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class ArcFaceLoss(nn.Module):
    """ArcFace Additive Angular Margin Loss.

    Projects features onto a hypersphere and enforces an additive angular
    margin between classes. Well suited to fine-grained classification where
    visually similar classes (e.g., Staffordshire vs AmStaff) need strong
    discriminative boundaries.

    For each sample, the logit of its ground-truth class is
    ``scale * cos(theta + margin)`` and every other logit is
    ``scale * cos(theta)``. Here ``theta`` is the angle between the
    L2-normalized embedding and the L2-normalized class weight. The logits are
    then fed to cross-entropy.

    Args:
        embed_dim: Feature embedding dimension
        num_classes: Number of classes
        scale: Feature scale (s). Default: 30.0
        margin: Angular margin (m) in radians. Default: 0.3
        label_smoothing: Smoothing factor passed to ``F.cross_entropy``.
            Default: 0.0
        eps: Lower bound applied to ``1 - cos^2`` before the square root.
            This keeps gradients finite when an embedding is exactly
            (anti-)aligned with a class weight. Default: 1e-7
    """

    def __init__(
        self,
        embed_dim: int,
        num_classes: int,
        scale: float = 30.0,
        margin: float = 0.3,
        label_smoothing: float = 0.0,
        eps: float = 1e-7,
    ):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.eps = eps

        # One learnable class-center vector per row; L2-normalized in forward().
        self.weight = nn.Parameter(torch.empty(num_classes, embed_dim))
        nn.init.xavier_uniform_(self.weight)

        # Constants for the identity cos(theta + m) = cos*cos(m) - sin*sin(m).
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Once theta > pi - m, cos(theta + m) stops decreasing in theta.
        # Past that threshold the forward pass falls back to the linear
        # penalty `cos - mm` instead.
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the ArcFace loss.

        Args:
            embeddings: (B, embed_dim) — raw features from backbone (NOT logits)
            labels: (B,) — ground truth class indices

        Returns:
            Scalar cross-entropy loss over the margin-adjusted, scaled logits.
        """
        embeddings = F.normalize(embeddings, p=2, dim=1)
        weight = F.normalize(self.weight, p=2, dim=1)

        # (B, num_classes) cosine similarity between embeddings and class centers.
        cosine = F.linear(embeddings, weight)
        # Clamp *before* the sqrt. sqrt's backward is grad / (2 * sqrt(x)), so
        # at x == 0 (|cos| == 1) it yields inf for target entries and
        # 0/0 = NaN for masked ones. Either result poisons the whole gradient.
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(min=self.eps))

        # cos(theta + m) via the angle-addition identity, with the fallback
        # for angles where theta + m would exceed pi.
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Apply the margin only to each sample's ground-truth class.
        target_mask = F.one_hot(labels.long(), num_classes=cosine.size(1)).bool()
        logits = torch.where(target_mask, phi, cosine) * self.scale

        return F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
|
|
|
|
class ArcFaceHead(nn.Module):
    """Combined ArcFace projection head — replaces the standard MLP + CE pipeline.

    Takes raw backbone features, projects them to an embedding space, then
    applies ArcFace:

    - With labels, ``forward`` returns the ArcFace (margin) loss.
    - Without labels, it returns scaled cosine-similarity logits against the
      learned class centers. No margin is applied in this case.

    Args:
        embed_dim: Dimension of the incoming backbone features.
        num_classes: Number of classes.
        projection_dim: Dimension of the projected embedding. Default: 512
        scale: ArcFace feature scale (s). Default: 30.0
        margin: ArcFace angular margin (m) in radians. Default: 0.3
        dropout: Dropout probability applied after the projection. Default: 0.3
        label_smoothing: Smoothing factor forwarded to the ArcFace loss.
            Default: 0.0
    """

    def __init__(
        self,
        embed_dim: int,
        num_classes: int,
        projection_dim: int = 512,
        scale: float = 30.0,
        margin: float = 0.3,
        dropout: float = 0.3,
        label_smoothing: float = 0.0,
    ):
        super().__init__()
        # LayerNorm -> Linear -> GELU -> Dropout into the ArcFace embedding space.
        self.projector = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, projection_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.arcface = ArcFaceLoss(
            embed_dim=projection_dim,
            num_classes=num_classes,
            scale=scale,
            margin=margin,
            label_smoothing=label_smoothing,
        )
        self.num_classes = num_classes

    def forward(
        self, features: torch.Tensor, labels: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Project features and return either the loss or class logits.

        Args:
            features: Backbone features; presumably (B, embed_dim) — the
                projector's LayerNorm/Linear operate on the last dimension.
            labels: Ground-truth class indices, or None for inference.

        Returns:
            If labels are given: the scalar ArcFace loss.
            Otherwise: (B, num_classes) cosine-similarity logits multiplied by
            the ArcFace scale.
        """
        projected = self.projector(features)

        if labels is not None:
            # Training path: margin-adjusted cross-entropy.
            return self.arcface(projected, labels)

        # Inference path: plain scaled cosine logits (no margin) against the
        # same normalized class centers the loss uses.
        projected = F.normalize(projected, p=2, dim=1)
        weight = F.normalize(self.arcface.weight, p=2, dim=1)
        return F.linear(projected, weight) * self.arcface.scale
|
|
|
|
class Poly1Loss(nn.Module):
    """Poly-1 Cross-Entropy Loss.

    Near drop-in replacement for CE. Adds a polynomial correction term,
    ``epsilon * (1 - p_t)``, that helps with hard examples. From the
    "PolyLoss" paper (ICLR 2022).

    ``p_t`` is the predicted probability mass on the (optionally
    label-smoothed) target distribution.

    Args:
        num_classes: Number of classes. Stored for backward compatibility; the
            class count used in ``forward`` is taken from ``logits.size(1)``.
            This matches what ``F.cross_entropy`` uses for its own label
            smoothing. Default: 120
        epsilon: Polynomial coefficient. Default: 1.0
        label_smoothing: Smoothing factor. Default: 0.1
    """

    def __init__(self, num_classes: int = 120, epsilon: float = 1.0, label_smoothing: float = 0.1):
        super().__init__()
        self.epsilon = epsilon
        self.num_classes = num_classes
        self.label_smoothing = label_smoothing

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the mean Poly-1 loss.

        Args:
            logits: (B, C) unnormalized class scores.
            labels: (B,) ground-truth class indices.

        Returns:
            Scalar: mean cross-entropy plus ``epsilon * mean(1 - p_t)``.
        """
        ce_loss = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)

        # Derive C from the logits (as cross_entropy does) rather than from
        # the constructor default, so any logits width works.
        num_classes = logits.size(1)
        probs = F.softmax(logits, dim=1)
        # Probability assigned to each sample's ground-truth class: (B,)
        p_target = probs.gather(1, labels.long().unsqueeze(1)).squeeze(1)

        # Dot product of probs with the smoothed target (1 - ls) * one_hot + ls / C.
        # Because probs sum to 1, this equals (1 - ls) * p_target + ls / C.
        # It reduces to p_target when ls == 0.
        ls = self.label_smoothing
        pt = (1.0 - ls) * p_target + ls / num_classes

        return ce_loss + self.epsilon * (1.0 - pt).mean()
|
|