File size: 578 Bytes
02ac88d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# src/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    For each pair with Euclidean distance ``d`` (from ``F.pairwise_distance``):

    * ``label == 1`` (same class, pulled together):   ``d ** 2``
    * ``label == 0`` (different class, pushed apart):  ``max(0, margin - d) ** 2``

    Args:
        margin: Distance beyond which a different-class pair contributes
            no loss.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Compute the mean contrastive loss for a batch of embedding pairs.

        Args:
            emb1: First embeddings of each pair; any shape accepted by
                ``F.pairwise_distance`` (presumably ``(batch, dim)`` —
                TODO confirm against callers).
            emb2: Second embeddings of each pair, same shape as ``emb1``.
            label: 1 for same-class pairs, 0 for different-class pairs.
                May be bool/int/float, and shaped like the distances
                (e.g. ``(batch,)``) or with a trailing singleton dim
                (e.g. ``(batch, 1)``).

        Returns:
            A tuple ``(loss, dist)``: the mean loss over the batch and the
            per-pair Euclidean distances.
        """
        # Euclidean distance between embedding pairs.
        dist = F.pairwise_distance(emb1, emb2)

        # Arithmetic below needs a numeric label: `1 - label` raises on bool
        # tensors, so convert to the distance dtype/device.
        label = torch.as_tensor(label, device=dist.device).to(dist.dtype)
        # A (batch, 1) label would broadcast against (batch,) distances into a
        # (batch, batch) matrix and silently corrupt the mean; align shapes
        # whenever there is one label per distance.
        if label.numel() == dist.numel():
            label = label.reshape(dist.shape)

        # label=1 -> same class (pull together), label=0 -> different class (push apart)
        pull = label * dist.pow(2)
        push = (1 - label) * F.relu(self.margin - dist).pow(2)

        return (pull + push).mean(), dist