| # src/loss.py | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings (Hadsell et al. style).

    A pair labelled ``1`` (same class) is penalized by its squared
    Euclidean distance, pulling the embeddings together. A pair labelled
    ``0`` (different class) is penalized only while it sits closer than
    ``margin``, pushing the embeddings apart.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        # Minimum separation required of dissimilar pairs before their
        # penalty drops to zero.
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Compute the loss for a batch of embedding pairs.

        Returns a tuple ``(loss, distances)``: the mean contrastive loss
        over the batch and the per-pair Euclidean distances.
        """
        distances = F.pairwise_distance(emb1, emb2)
        # label == 1: penalize any separation at all.
        pull_term = label * distances.pow(2)
        # label == 0: penalize only the shortfall below the margin
        # (relu zeroes out pairs already far enough apart).
        push_term = (1 - label) * F.relu(self.margin - distances).pow(2)
        per_pair_loss = pull_term + push_term
        return per_pair_loss.mean(), distances