| from torch import nn |
| from abc import abstractmethod |
|
|
| import torch |
| from torch import nn |
| from .cross_entropy_model import FBankNetV2 |
|
|
class TripletLoss(nn.Module):
    """Margin-based triplet loss computed on cosine distance.

    For every triplet the loss is ``max(d(a, p) - d(a, n) + margin, 0)``
    with ``d(x, y) = 1 - cosine_similarity(x, y)``.
    """

    def __init__(self, margin):
        """Store the margin and build the cosine-similarity layer.

        Args:
            margin: Value added to ``d(a, p) - d(a, n)`` before the hinge.
        """
        super().__init__()
        # Default nn.CosineSimilarity: similarity is taken along dim 1.
        self.cosine_similarity = nn.CosineSimilarity()
        self.margin = margin

    def forward(self, anchor_embeddings, positive_embeddings, negative_embeddings, reduction='mean'):
        """Compute the triplet loss for a batch of embeddings.

        Args:
            anchor_embeddings: Anchor embeddings.
                NOTE(review): presumably (batch, dim), since similarity is
                taken along dim 1 -- confirm against callers.
            positive_embeddings: Embeddings paired with each anchor as the
                positive example.
            negative_embeddings: Embeddings paired with each anchor as the
                negative example.
            reduction: ``'mean'`` averages the per-triplet losses and
                ``'none'`` returns them unreduced. Any other value
                (conventionally ``'sum'``) sums them.

        Returns:
            The reduced loss tensor, or the per-triplet losses when
            ``reduction == 'none'``.
        """
        positive_distance = 1 - self.cosine_similarity(anchor_embeddings, positive_embeddings)
        negative_distance = 1 - self.cosine_similarity(anchor_embeddings, negative_embeddings)

        # Hinge at zero; clamp avoids allocating a zero tensor just to
        # take an element-wise maximum against it.
        losses = torch.clamp(positive_distance - negative_distance + self.margin, min=0)

        if reduction == 'mean':
            return torch.mean(losses)
        if reduction == 'none':
            return losses
        # Every other value sums the losses.
        return torch.sum(losses)
|
|
|
|
class FBankTripletLossNet(FBankNetV2):
    """FBankNetV2 variant that embeds triplets and scores them with TripletLoss."""

    def __init__(self, num_layers, margin):
        """Build the base network and attach a triplet loss layer.

        Args:
            num_layers: Forwarded to ``FBankNetV2.__init__`` as ``num_layers``.
            margin: Margin handed to ``TripletLoss``.
        """
        super().__init__(num_layers=num_layers)
        self.loss_layer = TripletLoss(margin)

    def forward(self, anchor, positive, negative):
        """Embed the anchor, positive and negative batches.

        Each input goes through ``self.network``, is flattened to
        ``(anchor.shape[0], -1)`` and then projected by ``self.linear_layer``.
        NOTE(review): ``network`` and ``linear_layer`` are presumably set up
        by ``FBankNetV2`` -- confirm.

        Returns:
            Tuple ``(anchor_out, positive_out, negative_out)``.
        """
        # All three inputs are flattened using the anchor's batch size.
        batch_size = anchor.shape[0]

        def embed(inputs):
            features = self.network(inputs)
            flattened = features.reshape(batch_size, -1)
            return self.linear_layer(flattened)

        return embed(anchor), embed(positive), embed(negative)

    def loss(self, anchor, positive, negative, reduction='mean'):
        """Apply the triplet loss layer to three embedding batches.

        ``reduction`` is passed through unchanged to the loss layer.
        """
        return self.loss_layer(anchor, positive, negative, reduction)