# iljung1106
# Initial commit
# 546ff88
"""
Artist Style Embedding - Loss Functions
ArcFace + Multi-Similarity + Center Loss
"""
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class ArcFaceLoss(nn.Module):
    """ArcFace loss (additive angular margin) computed from cosine logits.

    The cosine of the target class, ``cos(theta)``, is replaced by
    ``cos(theta + margin)``; all logits are then multiplied by ``scale`` and
    fed to cross-entropy.

    Args:
        scale: Factor applied to every logit before cross-entropy.
        margin: Additive angular margin, in radians.
    """

    def __init__(self, scale: float = 64.0, margin: float = 0.5):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # cosine <= th means theta + margin >= pi; in that range the
        # margin-shifted cosine is replaced by the linear penalty
        # ``cosine - mm`` (see forward).
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, cosine: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean ArcFace cross-entropy over the batch.

        Args:
            cosine: Floating-point tensor indexed as (sample, class); each
                entry is used as ``cos(theta)`` for that sample/class pair.
            labels: Integer class index per sample; flattened to one index
                per row of ``cosine``.

        Returns:
            Scalar loss tensor as produced by ``F.cross_entropy``.
        """
        # Rounding can push |cosine| marginally above 1, which makes
        # 1 - cos^2 negative (sqrt -> NaN), and sqrt has an unbounded slope
        # at exactly 0 (NaN/inf gradients). Flooring the radicand at the
        # smallest normal number keeps sine ~0 there and zeroes the gradient
        # of the clamped entries instead.
        floor = torch.finfo(cosine.dtype).tiny
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=floor))

        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Boolean mask of the target class per row. Selecting with
        # torch.where (rather than blending with a 0/1 float mask) means a
        # non-finite value in the unselected branch cannot leak through
        # 0 * NaN.
        class_ids = torch.arange(cosine.size(1), device=cosine.device)
        target_mask = class_ids.unsqueeze(0) == labels.view(-1, 1)

        logits = torch.where(target_mask, phi, cosine) * self.scale
        return F.cross_entropy(logits, labels)
class MultiSimilarityLoss(nn.Module):
    """Multi-Similarity loss with per-anchor hard pair mining.

    Args:
        alpha: Weight applied to positive-pair similarities inside the
            exponent (and reciprocal scale of the positive term).
        beta: Weight applied to negative-pair similarities inside the
            exponent (and reciprocal scale of the negative term).
        base: Offset subtracted from every similarity before weighting.
        margin: Slack used when deciding which pairs count as hard.
    """

    def __init__(
        self,
        alpha: float = 2.0,
        beta: float = 50.0,
        base: float = 0.5,
        margin: float = 0.1,
    ):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.base = base
        self.margin = margin

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the loss averaged over anchors that have hard pairs.

        Args:
            embeddings: 2-D tensor with one embedding per row. Similarities
                are plain dot products, so they are cosine similarities only
                if the rows are L2-normalised (not checked here).
            labels: Class label per row of ``embeddings``.

        Returns:
            Scalar loss tensor. When no anchor has both a hard positive and
            a hard negative, a fresh zero tensor with ``requires_grad=True``
            is returned.
        """
        sim_mat = torch.matmul(embeddings, embeddings.t())
        batch = embeddings.size(0)

        labels = labels.view(-1, 1)
        same_class = labels == labels.t()
        off_diagonal = ~torch.eye(batch, dtype=torch.bool, device=same_class.device)
        pos_mask = same_class & off_diagonal  # an anchor is not its own positive
        neg_mask = ~same_class

        # log(1 + sum(exp(x))) == logsumexp([0, x...]); the logsumexp form
        # does not overflow when beta * (sim - base) is large.
        zero = sim_mat.new_zeros(1)

        anchor_losses = []
        for i in range(batch):
            pos_pair = sim_mat[i][pos_mask[i]]
            neg_pair = sim_mat[i][neg_mask[i]]
            if pos_pair.numel() == 0 or neg_pair.numel() == 0:
                continue

            # Keep negatives that come within `margin` of the weakest
            # positive, and positives within `margin` of the strongest
            # negative.
            hard_neg = neg_pair[neg_pair + self.margin > pos_pair.min()]
            hard_pos = pos_pair[pos_pair - self.margin < neg_pair.max()]
            if hard_pos.numel() == 0 or hard_neg.numel() == 0:
                continue

            pos_loss = torch.logsumexp(
                torch.cat([zero, -self.alpha * (hard_pos - self.base)]), dim=0
            ) / self.alpha
            neg_loss = torch.logsumexp(
                torch.cat([zero, self.beta * (hard_neg - self.base)]), dim=0
            ) / self.beta
            anchor_losses.append(pos_loss + neg_loss)

        if not anchor_losses:
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
        return torch.stack(anchor_losses).mean()
class CenterLoss(nn.Module):
    """Center loss: pulls each embedding toward a learnable class center.

    Args:
        num_classes: Number of rows (one center per class) in ``centers``.
        feat_dim: Length of each center vector.
    """

    def __init__(self, num_classes: int, feat_dim: int):
        super().__init__()
        # Draw the tensor first, overwrite it with Xavier-uniform values,
        # then register it as the trainable parameter.
        initial = torch.randn(num_classes, feat_dim)
        nn.init.xavier_uniform_(initial)
        self.centers = nn.Parameter(initial)

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the batch mean of squared distances to the labelled centers.

        Args:
            embeddings: One embedding per row; must be broadcast-compatible
                with the centers selected by ``labels``.
            labels: Indices into ``centers``, one per embedding.
        """
        offsets = embeddings - self.centers[labels]
        squared_distance = offsets.pow(2).sum(dim=1)
        return squared_distance.mean()
class CombinedMetricLoss(nn.Module):
    """Weighted sum of the ArcFace, Multi-Similarity and Center losses.

    Args:
        num_classes: Forwarded to ``CenterLoss``.
        embedding_dim: Forwarded to ``CenterLoss`` as ``feat_dim``.
        arcface_scale: Forwarded to ``ArcFaceLoss`` as ``scale``.
        arcface_margin: Forwarded to ``ArcFaceLoss`` as ``margin``.
        arcface_weight: Multiplier on the ArcFace term.
        ms_weight: Multiplier on the Multi-Similarity term.
        center_weight: Multiplier on the Center term.
    """

    def __init__(
        self,
        num_classes: int,
        embedding_dim: int = 512,
        arcface_scale: float = 64.0,
        arcface_margin: float = 0.5,
        arcface_weight: float = 0.2,
        ms_weight: float = 3.0,
        center_weight: float = 0.01,
    ):
        super().__init__()

        # Scalar mixing weights.
        self.arcface_weight = arcface_weight
        self.ms_weight = ms_weight
        self.center_weight = center_weight

        # Component losses; the Multi-Similarity loss uses its own defaults.
        self.arcface = ArcFaceLoss(scale=arcface_scale, margin=arcface_margin)
        self.ms_loss = MultiSimilarityLoss()
        self.center_loss = CenterLoss(num_classes=num_classes, feat_dim=embedding_dim)

    def forward(
        self,
        embeddings: torch.Tensor,
        cosine: torch.Tensor,
        labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the combined loss and a dict of scalar components.

        Args:
            embeddings: Passed to the Multi-Similarity and Center losses.
            cosine: Passed to the ArcFace loss.
            labels: Passed to all three losses.

        Returns:
            ``(total, metrics)`` where ``total`` is the weighted sum as a
            tensor and ``metrics`` maps ``'loss_total'``, ``'loss_arcface'``,
            ``'loss_ms'`` and ``'loss_center'`` to Python floats.
        """
        arc_term = self.arcface(cosine, labels)
        ms_term = self.ms_loss(embeddings, labels)
        center_term = self.center_loss(embeddings, labels)

        total = self.arcface_weight * arc_term + self.ms_weight * ms_term
        total = total + self.center_weight * center_term

        names = ('loss_total', 'loss_arcface', 'loss_ms', 'loss_center')
        tensors = (total, arc_term, ms_term, center_term)
        metrics = {name: value.item() for name, value in zip(names, tensors)}
        return total, metrics
def create_loss(config, num_classes: int) -> CombinedMetricLoss:
    """Build a ``CombinedMetricLoss`` from a configuration object.

    Reads ``config.model.embedding_dim`` and, from ``config.loss``,
    ``arcface_scale``, ``arcface_margin``, ``arcface_weight``,
    ``ms_loss_weight`` and ``center_loss_weight``.
    """
    embedding_dim = config.model.embedding_dim
    loss_cfg = config.loss
    settings = {
        'num_classes': num_classes,
        'embedding_dim': embedding_dim,
        'arcface_scale': loss_cfg.arcface_scale,
        'arcface_margin': loss_cfg.arcface_margin,
        'arcface_weight': loss_cfg.arcface_weight,
        'ms_weight': loss_cfg.ms_loss_weight,
        'center_weight': loss_cfg.center_loss_weight,
    }
    return CombinedMetricLoss(**settings)