|
|
""" |
|
|
Artist Style Embedding - Loss Functions |
|
|
ArcFace + Multi-Similarity + Center Loss |
|
|
""" |
|
|
import math |
|
|
from typing import Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class ArcFaceLoss(nn.Module):
    """ArcFace loss (Additive Angular Margin Loss).

    Replaces the target-class logit ``cos(theta)`` with ``cos(theta + m)``,
    multiplies every logit by ``s`` and applies cross-entropy.

    Args:
        scale: Logit scale ``s`` applied before cross-entropy.
        margin: Additive angular margin ``m`` in radians.
        eps: Lower bound applied to ``1 - cos^2`` before the square root.
            Keeps the forward pass finite when a cosine overshoots 1 through
            rounding, and the gradient finite when ``|cos| == 1``.
    """

    def __init__(self, scale: float = 64.0, margin: float = 0.5, eps: float = 1e-7):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.eps = eps
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # cosine > th  <=>  theta + m < pi. Beyond that point cos(theta + m)
        # stops decreasing in theta, so the linear fallback cos - mm is used.
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, cosine: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the ArcFace loss.

        Args:
            cosine: Sample-to-class cosine similarities, indexed as
                (batch, num_classes). Presumably produced from L2-normalised
                features and weights by the caller (TODO confirm).
            labels: Integer (int64) class index per sample.

        Returns:
            Scalar cross-entropy (mean reduction) over the margin-adjusted,
            scaled logits.
        """
        # Clamp before sqrt: normalised dot products can land at 1 + ulp
        # (sqrt of a negative -> NaN), and d/dx sqrt(x) is infinite at x == 0.
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(min=self.eps))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Pick the margin-adjusted logit only at the target class. A select
        # avoids the 0 * non-finite products a float one-hot blend would form.
        target = F.one_hot(labels.view(-1), num_classes=cosine.size(1)).bool()
        logits = torch.where(target, phi, cosine) * self.scale

        return F.cross_entropy(logits, labels)
|
|
|
|
|
|
|
|
class MultiSimilarityLoss(nn.Module):
    """Multi-Similarity loss with hard pair mining.

    For each anchor ``i``, row ``i`` of ``embeddings @ embeddings.T`` is split
    into positives (same label, the anchor itself excluded) and negatives
    (different label). Pairs are then mined:

    * a negative is kept if ``sim + margin > min(positive sims)``;
    * a positive is kept if ``sim - margin < max(negative sims)``.

    The per-anchor loss is::

        1/alpha * log(1 + sum_p exp(-alpha * (S_p - base)))
      + 1/beta  * log(1 + sum_n exp( beta  * (S_n - base)))

    It is averaged over anchors that keep at least one positive and one
    negative.

    Similarities are raw dot products. ``base`` and ``margin`` look like
    cosine-scale values, so embeddings are presumably L2-normalised by the
    caller (TODO confirm).

    Args:
        alpha: Scale applied to positive similarities.
        beta: Scale applied to negative similarities.
        base: Similarity offset subtracted before scaling.
        margin: Mining margin used when selecting hard pairs.
    """

    def __init__(
        self,
        alpha: float = 2.0,
        beta: float = 50.0,
        base: float = 0.5,
        margin: float = 0.1,
    ):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.base = base
        self.margin = margin

    @staticmethod
    def _log1p_sum_exp(x: torch.Tensor) -> torch.Tensor:
        """Return ``log(1 + sum(exp(x)))`` for a 1-D tensor without overflow.

        Equals ``softplus(logsumexp(x))``. ``exp(x)`` is never materialised,
        so large ``beta * (S - base)`` values stay finite.
        """
        return F.softplus(torch.logsumexp(x, dim=0))

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the mean Multi-Similarity loss over valid anchors.

        Args:
            embeddings: Batch of embedding rows; the loss uses their pairwise
                dot products.
            labels: Class label per row, reshaped to a column internally.

        Returns:
            Scalar loss. If no anchor yields mined pairs, returns a zero
            tensor created with ``requires_grad=True`` that is not connected
            to ``embeddings``.
        """
        sim_mat = torch.matmul(embeddings, embeddings.t())
        n = sim_mat.size(0)

        labels = labels.view(-1, 1)
        same_label = labels == labels.t()
        not_self = ~torch.eye(n, dtype=torch.bool, device=same_label.device)
        pos_mask = same_label & not_self
        neg_mask = ~same_label

        per_anchor = []
        for i in range(n):
            row = sim_mat[i]
            pos_pair = row[pos_mask[i]]
            neg_pair = row[neg_mask[i]]
            if pos_pair.numel() == 0 or neg_pair.numel() == 0:
                continue

            # Both mining thresholds come from the unfiltered pair sets.
            hard_neg = neg_pair[neg_pair + self.margin > pos_pair.min()]
            hard_pos = pos_pair[pos_pair - self.margin < neg_pair.max()]
            if hard_pos.numel() == 0 or hard_neg.numel() == 0:
                continue

            pos_loss = self._log1p_sum_exp(-self.alpha * (hard_pos - self.base)) / self.alpha
            neg_loss = self._log1p_sum_exp(self.beta * (hard_neg - self.base)) / self.beta
            per_anchor.append(pos_loss + neg_loss)

        if not per_anchor:
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        return torch.stack(per_anchor).mean()
|
|
|
|
|
|
|
|
class CenterLoss(nn.Module):
    """Center loss: pull each embedding toward a learnable centre for its class.

    The value is the batch mean of the squared Euclidean distance between
    every embedding and the centre indexed by its label (no 1/2 factor).

    Args:
        num_classes: Number of class centres (rows of ``centers``).
        feat_dim: Width of each centre; must match the embedding width.
    """

    def __init__(self, num_classes: int, feat_dim: int):
        super().__init__()
        initial = torch.randn(num_classes, feat_dim)
        self.centers = nn.Parameter(initial)
        # Xavier-uniform overwrites the randn draw above; the draw is kept so
        # the global RNG stream is the same as before.
        nn.init.xavier_uniform_(self.centers)

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean squared distance of each embedding to its class centre."""
        offsets = embeddings - self.centers[labels]
        per_sample = offsets.pow(2).sum(dim=1)
        return per_sample.mean()
|
|
|
|
|
|
|
|
class CombinedMetricLoss(nn.Module):
    """Weighted sum of the ArcFace, Multi-Similarity and Center losses.

    ``total = arcface_weight * L_arc + ms_weight * L_ms + center_weight * L_center``

    Args:
        num_classes: Number of identity classes (sizes the Center-loss centres).
        embedding_dim: Embedding width used for the Center-loss centres.
        arcface_scale: Logit scale passed to ``ArcFaceLoss``.
        arcface_margin: Angular margin passed to ``ArcFaceLoss``.
        arcface_weight: Coefficient on the ArcFace term.
        ms_weight: Coefficient on the Multi-Similarity term.
        center_weight: Coefficient on the Center term.
    """

    def __init__(
        self,
        num_classes: int,
        embedding_dim: int = 512,
        arcface_scale: float = 64.0,
        arcface_margin: float = 0.5,
        arcface_weight: float = 0.2,
        ms_weight: float = 3.0,
        center_weight: float = 0.01,
    ):
        super().__init__()

        # Component criteria, registered as submodules in this order.
        # CenterLoss owns an nn.Parameter, so its centres appear in
        # this module's parameters().
        self.arcface = ArcFaceLoss(scale=arcface_scale, margin=arcface_margin)
        self.ms_loss = MultiSimilarityLoss()
        self.center_loss = CenterLoss(num_classes=num_classes, feat_dim=embedding_dim)

        # Mixing coefficients.
        self.arcface_weight = arcface_weight
        self.ms_weight = ms_weight
        self.center_weight = center_weight

    def forward(
        self,
        embeddings: torch.Tensor,
        cosine: torch.Tensor,
        labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the weighted total loss and a dict of Python-float stats.

        Args:
            embeddings: Embeddings fed to the Multi-Similarity and Center terms.
            cosine: Class cosine similarities fed to the ArcFace term.
            labels: Class label per sample, shared by all three terms.

        Returns:
            ``(total, stats)``. ``total`` is the differentiable weighted sum.
            ``stats`` maps ``loss_total``, ``loss_arcface``, ``loss_ms`` and
            ``loss_center`` to their ``.item()`` values.
        """
        arc = self.arcface(cosine, labels)
        ms = self.ms_loss(embeddings, labels)
        center = self.center_loss(embeddings, labels)

        weighted_arc = self.arcface_weight * arc
        weighted_ms = self.ms_weight * ms
        weighted_center = self.center_weight * center
        total = weighted_arc + weighted_ms + weighted_center

        stats = {
            'loss_total': total.item(),
            'loss_arcface': arc.item(),
            'loss_ms': ms.item(),
            'loss_center': center.item(),
        }
        return total, stats
|
|
|
|
|
|
|
|
def create_loss(config, num_classes: int) -> CombinedMetricLoss:
    """Build the combined metric-learning criterion from a config object.

    Reads ``config.model.embedding_dim`` and these ``config.loss`` fields:
    ``arcface_scale``, ``arcface_margin``, ``arcface_weight``,
    ``ms_loss_weight`` and ``center_loss_weight``.

    Args:
        config: Config object exposing ``model`` and ``loss`` sections.
        num_classes: Number of identity classes.

    Returns:
        A freshly constructed ``CombinedMetricLoss``.
    """
    embedding_dim = config.model.embedding_dim
    loss_cfg = config.loss
    return CombinedMetricLoss(
        num_classes=num_classes,
        embedding_dim=embedding_dim,
        arcface_scale=loss_cfg.arcface_scale,
        arcface_margin=loss_cfg.arcface_margin,
        arcface_weight=loss_cfg.arcface_weight,
        ms_weight=loss_cfg.ms_loss_weight,
        center_weight=loss_cfg.center_loss_weight,
    )
|
|
|