"""
Artist Style Embedding - Loss Functions

ArcFace + Multi-Similarity + Center Loss
"""

import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _at_least_float32(x: torch.Tensor) -> torch.Tensor:
    """Upcast half/bfloat16 tensors to float32; float32/float64 pass through."""
    return x.to(torch.promote_types(x.dtype, torch.float32))


class ArcFaceLoss(nn.Module):
    """ArcFace Loss (Additive Angular Margin Loss).

    Replaces the target-class logit ``cos(theta_y)`` with ``cos(theta_y + m)``,
    multiplies every logit by ``scale`` and applies cross-entropy.

    Args:
        scale: Logit scale applied to all classes.
        margin: Additive angular margin ``m`` in radians.
        eps: Lower bound applied to ``1 - cos^2`` before the square root. It
            keeps the forward pass finite when rounding pushes ``|cos|`` above 1,
            and the backward pass finite at ``|cos| == 1``, where the
            derivative of ``sqrt`` at 0 is infinite.
    """

    def __init__(self, scale: float = 64.0, margin: float = 0.5, eps: float = 1e-7):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.eps = eps
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # For theta >= pi - m, cos(theta + m) is no longer decreasing in theta,
        # so past this threshold a linear penalty ``cos - mm`` is used instead.
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, cosine: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean ArcFace cross-entropy over the batch.

        Args:
            cosine: ``(N, C)`` cosine similarities of each sample to each class
                (presumably from an L2-normalized linear head; confirm with caller).
            labels: ``(N,)`` int64 target class indices.

        Returns:
            Scalar loss tensor.
        """
        cosine = _at_least_float32(cosine)
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(min=self.eps))
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Put phi only in the target column. Using torch.where (rather than
        # ``one_hot * phi + (1 - one_hot) * cosine``) keeps non-target logits
        # independent of phi, so nothing in phi can leak into them.
        target = F.one_hot(labels, num_classes=cosine.size(1)).bool()
        logits = torch.where(target, phi, cosine) * self.scale
        return F.cross_entropy(logits, labels)


class MultiSimilarityLoss(nn.Module):
    """Multi-Similarity Loss for hard sample mining.

    For each anchor ``i``:

    * hard negatives are negatives with ``s_ij + margin > min positive s_ik``;
    * hard positives are positives with ``s_ij - margin < max negative s_ik``.

    The anchor contributes::

        1/alpha * log(1 + sum_hardpos exp(-alpha * (s - base)))
      + 1/beta  * log(1 + sum_hardneg exp( beta  * (s - base)))

    The loss is the mean over anchors that have at least one hard positive and
    at least one hard negative. It is 0 when no anchor qualifies.

    Similarity is the raw dot product; ``base`` and ``margin`` are presumably
    tuned for L2-normalized embeddings (cosine similarity). Confirm with caller.

    Args:
        alpha: Positive-pair temperature.
        beta: Negative-pair temperature.
        base: Similarity offset (lambda).
        margin: Hard-mining margin.
    """

    def __init__(
        self,
        alpha: float = 2.0,
        beta: float = 50.0,
        base: float = 0.5,
        margin: float = 0.1,
    ):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.base = base
        self.margin = margin

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean Multi-Similarity loss.

        Args:
            embeddings: ``(N, D)`` embedding batch.
            labels: ``N`` class labels (reshaped to a column internally).

        Returns:
            Scalar loss tensor.
        """
        # beta * (s - base) reaches ~25 for unit vectors; exp() of that
        # overflows float16, so the math runs in at least float32.
        sim_mat = _at_least_float32(torch.matmul(embeddings, embeddings.t()))
        if sim_mat.size(0) < 2:
            # No pairs at all; keep the result attached to the graph.
            return sim_mat.sum() * 0.0

        labels = labels.view(-1, 1)
        same_label = labels == labels.t()
        self_pair = torch.eye(same_label.size(0), dtype=torch.bool, device=same_label.device)
        pos_mask = same_label & ~self_pair
        neg_mask = ~same_label

        # Per-anchor mining thresholds. Rows with no positives (or negatives)
        # give +inf (or -inf), so nothing is mined for those anchors.
        pos_min = sim_mat.masked_fill(~pos_mask, float("inf")).amin(dim=1, keepdim=True)
        neg_max = sim_mat.masked_fill(~neg_mask, float("-inf")).amax(dim=1, keepdim=True)

        hard_pos = pos_mask & (sim_mat - self.margin < neg_max)
        hard_neg = neg_mask & (sim_mat + self.margin > pos_min)
        valid = hard_pos.any(dim=1) & hard_neg.any(dim=1)

        # log(1 + sum_j exp(x_j)) == logsumexp([0, x_1, x_2, ...]), computed
        # without overflow. Unmined pairs are set to -inf so they contribute
        # exactly 0 to the sum and 0 to the gradient.
        zero_col = sim_mat.new_zeros(sim_mat.size(0), 1)
        pos_logits = (-self.alpha * (sim_mat - self.base)).masked_fill(~hard_pos, float("-inf"))
        neg_logits = (self.beta * (sim_mat - self.base)).masked_fill(~hard_neg, float("-inf"))
        pos_loss = torch.logsumexp(torch.cat([zero_col, pos_logits], dim=1), dim=1) / self.alpha
        neg_loss = torch.logsumexp(torch.cat([zero_col, neg_logits], dim=1), dim=1) / self.beta

        per_anchor = torch.where(valid, pos_loss + neg_loss, torch.zeros_like(pos_loss))
        # Mean over valid anchors. The clamp makes the no-valid case 0 / 1 = 0
        # without a host sync or a graph-detached constant.
        return per_anchor.sum() / valid.sum().clamp(min=1)


class CenterLoss(nn.Module):
    """Center Loss for intra-class compactness.

    Holds one learnable center per class (an ``nn.Parameter``, so an optimizer
    updates it). The loss is the batch mean of the squared Euclidean distance
    between each embedding and its class center.

    Args:
        num_classes: Number of classes (rows of ``centers``).
        feat_dim: Embedding dimension (columns of ``centers``).
    """

    def __init__(self, num_classes: int, feat_dim: int):
        super().__init__()
        # xavier_uniform_ overwrites every element, so start from uninitialized storage.
        self.centers = nn.Parameter(torch.empty(num_classes, feat_dim))
        nn.init.xavier_uniform_(self.centers)

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return mean over the batch of ``||embedding - centers[label]||^2``."""
        centers_batch = self.centers[labels]
        return (embeddings - centers_batch).pow(2).sum(dim=1).mean()


class CombinedMetricLoss(nn.Module):
    """Combined loss: ArcFace + Multi-Similarity + Center.

    ``total = arcface_weight * L_arcface + ms_weight * L_ms + center_weight * L_center``

    Args:
        num_classes: Number of classes (for the center loss).
        embedding_dim: Embedding dimension (for the center loss).
        arcface_scale: ArcFace logit scale.
        arcface_margin: ArcFace angular margin (radians).
        arcface_weight: Weight of the ArcFace term.
        ms_weight: Weight of the Multi-Similarity term.
        center_weight: Weight of the center-loss term.
        ms_alpha: Multi-Similarity positive temperature.
        ms_beta: Multi-Similarity negative temperature.
        ms_base: Multi-Similarity similarity offset.
        ms_margin: Multi-Similarity hard-mining margin.
    """

    def __init__(
        self,
        num_classes: int,
        embedding_dim: int = 512,
        arcface_scale: float = 64.0,
        arcface_margin: float = 0.5,
        arcface_weight: float = 0.2,
        ms_weight: float = 3.0,
        center_weight: float = 0.01,
        ms_alpha: float = 2.0,
        ms_beta: float = 50.0,
        ms_base: float = 0.5,
        ms_margin: float = 0.1,
    ):
        super().__init__()
        self.arcface = ArcFaceLoss(scale=arcface_scale, margin=arcface_margin)
        self.ms_loss = MultiSimilarityLoss(
            alpha=ms_alpha, beta=ms_beta, base=ms_base, margin=ms_margin
        )
        self.center_loss = CenterLoss(num_classes=num_classes, feat_dim=embedding_dim)
        self.arcface_weight = arcface_weight
        self.ms_weight = ms_weight
        self.center_weight = center_weight

    def forward(
        self,
        embeddings: torch.Tensor,
        cosine: torch.Tensor,
        labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the weighted total loss and per-term values for logging.

        Args:
            embeddings: ``(N, embedding_dim)`` embedding batch.
            cosine: ``(N, num_classes)`` class cosines for ArcFace.
            labels: ``(N,)`` int64 class indices.

        Returns:
            ``(total, stats)``: ``total`` is the differentiable scalar loss;
            ``stats`` maps ``loss_total``, ``loss_arcface``, ``loss_ms`` and
            ``loss_center`` to Python floats.
        """
        loss_arc = self.arcface(cosine, labels)
        loss_ms = self.ms_loss(embeddings, labels)
        loss_center = self.center_loss(embeddings, labels)

        total = (
            self.arcface_weight * loss_arc
            + self.ms_weight * loss_ms
            + self.center_weight * loss_center
        )

        # One device->host transfer instead of four separate .item() syncs.
        # double() is exact for float32/float64, so values match .item().
        values = torch.stack(
            [t.detach().double() for t in (total, loss_arc, loss_ms, loss_center)]
        ).tolist()
        keys = ('loss_total', 'loss_arcface', 'loss_ms', 'loss_center')
        return total, dict(zip(keys, values))


def create_loss(config, num_classes: int) -> CombinedMetricLoss:
    """Build a ``CombinedMetricLoss`` from a config object.

    Reads ``config.model.embedding_dim`` and ``config.loss.{arcface_scale,
    arcface_margin, arcface_weight, ms_loss_weight, center_loss_weight}``.
    Multi-Similarity hyperparameters keep their defaults.
    """
    return CombinedMetricLoss(
        num_classes=num_classes,
        embedding_dim=config.model.embedding_dim,
        arcface_scale=config.loss.arcface_scale,
        arcface_margin=config.loss.arcface_margin,
        arcface_weight=config.loss.arcface_weight,
        ms_weight=config.loss.ms_loss_weight,
        center_weight=config.loss.center_loss_weight,
    )