File size: 303 Bytes
4c1e73e |
1 2 3 4 5 6 7 8 9 10 |
import torch.nn.functional as F
class TripletLoss:
    """Triplet Margin Loss.

    Thin callable wrapper around ``torch.nn.functional.triplet_margin_loss``.
    For each triplet the per-sample loss is
    ``max(d(anchor, positive) - d(anchor, negative) + margin, 0)``, where
    ``d`` is the ``p``-norm pairwise distance. The per-sample losses are then
    reduced according to ``reduction``.

    Args:
        margin: Required gap between the anchor-negative and anchor-positive
            distances. Defaults to ``1.0``.
        p: Norm degree of the pairwise distance. Defaults to ``2.0``
            (Euclidean), the functional default.
        eps: Small constant forwarded to the distance computation for
            numerical stability. Defaults to ``1e-6``, the functional default.
        swap: If ``True``, use the distance swap (the smaller of
            anchor-negative and positive-negative distance). Defaults to
            ``False``.
        reduction: ``'mean'`` (default, the original behavior), ``'sum'``,
            or ``'none'`` to return the per-sample losses.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``,
            ``'sum'``.
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, margin=1.0, *, p=2.0, eps=1e-6, swap=False, reduction="mean"):
        # Fail at construction time rather than on the first forward call.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.margin = margin
        self.p = p
        self.eps = eps
        self.swap = swap
        self.reduction = reduction

    def __call__(self, anchor, positive, negative):
        """Compute the triplet margin loss for the given embeddings.

        Args:
            anchor: Anchor embeddings tensor.
            positive: Embeddings that should lie close to ``anchor``.
            negative: Embeddings that should lie far from ``anchor``.
            (Presumably all three are ``(batch, dim)``. The shape checks are
            delegated to PyTorch.)

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or the per-sample losses
            for ``'none'``.
        """
        return F.triplet_margin_loss(
            anchor,
            positive,
            negative,
            margin=self.margin,
            p=self.p,
            eps=self.eps,
            swap=self.swap,
            reduction=self.reduction,
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}(margin={self.margin!r}, p={self.p!r}, "
            f"eps={self.eps!r}, swap={self.swap!r}, reduction={self.reduction!r})"
        )
|