import torch.nn.functional as F class TripletLoss: """Triplet Margin Loss""" def __init__(self, margin=1.0): self.margin = margin def __call__(self, anchor, positive, negative): return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, reduction='mean')