import torch class ContrastiveLoss(torch.nn.Module): """ Contrastive loss Adapted from: (OnlineContrastiveLoss) https://github.com/adambielski/siamese-triplet/blob/master/losses.py Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf """ def __init__(self, margin): super(ContrastiveLoss, self).__init__() self.margin = margin def forward(self, anchors, negatives, positives): anchors = anchors / anchors.norm(dim=-1, keepdim=True) negatives = negatives / negatives.norm(dim=-1, keepdim=True) positives = positives / positives.norm(dim=-1, keepdim=True) positive_loss = (anchors - positives).pow(2).sum(1) negative_loss = torch.nn.functional.relu( self.margin - (anchors - negatives).pow(2).sum(1).sqrt()).pow(2) loss = 0.5 * torch.cat([positive_loss, negative_loss], dim=0) return loss.mean()