LETTER / src /loss.py
Sharath33's picture
Upload folder using huggingface_hub
02ac88d verified
# src/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss for pairs of embeddings.

    For each pair with Euclidean distance ``d`` and label ``y``::

        loss = y * d**2 + (1 - y) * max(0, margin - d)**2

    ``y = 1`` marks a similar pair, which is pulled together. ``y = 0`` marks
    a dissimilar pair, which is pushed apart until it is at least ``margin``
    away. No 1/2 factor is applied to either term.

    Args:
        margin: Distance at or beyond which dissimilar pairs contribute no loss.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Compute the mean contrastive loss over a batch of embedding pairs.

        Args:
            emb1, emb2: Embedding batches compared with ``F.pairwise_distance``.
                Presumably shape (batch, dim); confirm against callers.
            label: Pair labels (1 = same class, 0 = different class).
                Accepted forms:
                - a float, int or bool tensor, or a Python scalar;
                - a tensor holding exactly one label per pair, e.g. shape
                  (batch,) or (batch, 1), which is reshaped to match the
                  per-pair distances.
                Any other shape is broadcast as-is.

        Returns:
            Tuple ``(loss, dist)``: the scalar mean loss and the per-pair
            distances.
        """
        # Euclidean distance between embedding pairs
        dist = F.pairwise_distance(emb1, emb2)

        # Cast to the distance dtype/device: `1 - label` is unsupported for
        # bool tensors. Python scalars are accepted here too.
        label = torch.as_tensor(label, dtype=dist.dtype, device=dist.device)
        # A (batch, 1) label column would otherwise broadcast against the
        # (batch,) distances into a (batch, batch) matrix and silently
        # corrupt the mean, so align one-label-per-pair inputs explicitly.
        if label.numel() == dist.numel():
            label = label.reshape(dist.shape)

        # label=1 -> same class (pull together), label=0 -> different class (push apart)
        pull = label * dist.pow(2)
        push = (1 - label) * F.relu(self.margin - dist).pow(2)
        loss = pull + push
        return loss.mean(), dist