# src/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairs of embeddings.

    For each pair with Euclidean distance ``d``:

    * ``label == 1`` (same class): contributes ``d ** 2``, pulling the pair
      together.
    * ``label == 0`` (different class): contributes
      ``max(0, margin - d) ** 2``, pushing the pair apart until it is at
      least ``margin`` away.

    Args:
        margin: Distance beyond which dissimilar pairs contribute no loss.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Compute the mean contrastive loss for a batch of embedding pairs.

        Args:
            emb1: First embeddings of each pair. The distance is taken over
                the last dimension (presumably ``(batch, dim)``; confirm
                against callers).
            emb2: Second embeddings of each pair, same shape as ``emb1``.
            label: Per-pair labels, 1 for same class and 0 for different
                class. Accepts any tensor-like value (float, int or bool).
                A label whose element count matches the number of pairs,
                e.g. shape ``(batch, 1)``, is reshaped to line up with the
                distances.

        Returns:
            A tuple ``(loss, dist)``. ``loss`` is the scalar mean loss and
            ``dist`` holds the per-pair Euclidean distances.
        """
        # Euclidean distance between embedding pairs.
        dist = F.pairwise_distance(emb1, emb2)

        # Normalize labels to the distance dtype and device. Bool labels
        # would otherwise make `1 - label` raise, and integer or CPU labels
        # could mismatch `dist`.
        label = torch.as_tensor(label, dtype=dist.dtype, device=dist.device)
        # Align e.g. (B, 1) labels with (B,) distances. Without this, the
        # expression below silently broadcasts to (B, B) and `mean()` is
        # wrong. Scalar labels keep their broadcasting behavior.
        if label.shape != dist.shape and label.numel() == dist.numel():
            label = label.reshape(dist.shape)

        # label=1 → same class (pull together), label=0 → different class (push apart)
        loss = label * dist.pow(2) + \
            (1 - label) * F.relu(self.margin - dist).pow(2)
        return loss.mean(), dist