Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torch.distributions.kl import kl_divergence | |
| from torch.distributions.normal import Normal | |
| from torch.nn.functional import relu | |
class BatchHardTripletLoss(nn.Module):
    """Batch-hard triplet loss.

    For every anchor in the batch, the farthest embedding sharing its label
    (hardest positive) and the closest embedding with a different label
    (hardest negative) are selected. The per-anchor hinge
    ``relu(d(a, p) - d(a, n) + margin)`` is then aggregated according to ``agg``.
    """

    _AGGREGATIONS = ('sum', 'mean', 'none')

    def __init__(self, margin=1., squared=False, agg='sum'):
        """
        Initialize the loss function.

        Args:
            margin: hinge margin between hardest-positive and hardest-negative
                distances.
            squared: if True, use squared Euclidean distance; otherwise use
                Euclidean distance.
            agg: how to aggregate the per-anchor losses in a batch:
                'sum' (default), 'mean', or 'none' (returns one value per anchor).

        Raises:
            ValueError: if ``agg`` is not one of 'sum', 'mean', 'none'.
        """
        super().__init__()
        if agg not in self._AGGREGATIONS:
            raise ValueError(
                f"agg must be one of {self._AGGREGATIONS}, got {agg!r}")
        self.margin = margin
        self.squared = squared
        self.agg = agg
        # Added under the sqrt so its gradient stays finite at zero distance
        # (note: the diagonal therefore becomes sqrt(eps) rather than exactly 0).
        self.eps = 1e-8

    def get_pairwise_distances(self, embeddings):
        """
        Compute the Euclidean distance for all possible pairs of embeddings.

        Args:
            embeddings: 2-D tensor of shape (N, D) (``mm`` requires 2-D input).

        Returns:
            (N, N) tensor of squared or plain Euclidean distances, depending
            on ``self.squared``.
        """
        ab = embeddings.mm(embeddings.t())
        a_squared = ab.diag().unsqueeze(1)
        b_squared = ab.diag().unsqueeze(0)
        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2; relu clips tiny negative
        # values caused by floating-point cancellation.
        distances = relu(a_squared - 2 * ab + b_squared)
        if not self.squared:
            distances = torch.sqrt(distances + self.eps)
        return distances

    def hardest_triplet_mining(self, dist_mat, labels):
        """
        Select the hardest positive and hardest negative distance per anchor.

        Args:
            dist_mat: square (N, N) pairwise distance matrix.
            labels: N labels (any shape with N elements, e.g. (N,)).

        Returns:
            Tuple ``(dist_ap, dist_an)`` of (N, 1) tensors. An anchor with no
            negative in the batch gets ``dist_an = inf``, which yields zero loss.

        Raises:
            ValueError: if ``dist_mat`` is not a square 2-D matrix.
        """
        if dist_mat.dim() != 2 or dist_mat.size(0) != dist_mat.size(1):
            raise ValueError(
                f"dist_mat must be a square 2-D matrix, got shape {tuple(dist_mat.shape)}")
        labels = labels.reshape(-1)
        is_pos = labels.unsqueeze(0) == labels.unsqueeze(1)
        is_neg = ~is_pos
        # Distances are >= 0, so zeroing non-positives never affects the max.
        dist_ap, _ = torch.max(
            dist_mat.masked_fill(is_neg, 0.0), dim=1, keepdim=True)
        # Non-negatives (including the anchor itself) must be excluded with
        # +inf; zeroing them would make the min collapse to 0 for every anchor.
        dist_an, _ = torch.min(
            dist_mat.masked_fill(is_pos, float('inf')), dim=1, keepdim=True)
        return dist_ap, dist_an

    def forward(self, embeddings, labels):
        """
        Compute the batch-hard triplet loss.

        Args:
            embeddings: (N, D) embedding tensor.
            labels: N class labels, one per embedding.

        Returns:
            Scalar tensor for agg='sum'/'mean'; (N,) tensor for agg='none'.

        Raises:
            ValueError: if ``self.agg`` is not a supported aggregation.
        """
        distances = self.get_pairwise_distances(embeddings)
        dist_ap, dist_an = self.hardest_triplet_mining(distances, labels)
        per_anchor = relu(dist_ap - dist_an + self.margin)
        if self.agg == 'sum':
            return per_anchor.sum()
        if self.agg == 'mean':
            return per_anchor.mean()
        if self.agg == 'none':
            return per_anchor.squeeze(1)
        raise ValueError(
            f"agg must be one of {self._AGGREGATIONS}, got {self.agg!r}")
class VAELoss(nn.Module):
    """Variational auto-encoder objective: KL regulariser plus summed BCE reconstruction.

    The KL term is measured against a standard normal prior N(0, 1) that is
    built with ``zeros_like``/``ones_like`` from the posterior's own ``mean``
    and ``stddev``, so it matches their shape, dtype and device.
    """

    def __init__(self):
        super().__init__()
        # Binary cross-entropy summed (not averaged) over all elements.
        self.reconstruction_loss = nn.BCELoss(reduction='sum')

    def kl_divergence_loss(self, q_dist):
        """Return KL(q_dist || N(0, 1)) summed over the last dimension.

        Args:
            q_dist: approximate posterior exposing ``mean`` and ``stddev``,
                with a KL divergence registered against ``Normal``
                (e.g. a ``Normal``).

        Returns:
            Tensor of per-element KL values with the last dimension summed away.
        """
        standard_prior = Normal(
            loc=torch.zeros_like(q_dist.mean),
            scale=torch.ones_like(q_dist.stddev),
        )
        elementwise_kl = kl_divergence(q_dist, standard_prior)
        return elementwise_kl.sum(dim=-1)

    def forward(self, output, target, encoding):
        """Return the total VAE loss (scalar).

        Args:
            output: decoder output; ``BCELoss`` expects values in [0, 1].
            target: reconstruction target with the same shape as ``output``.
            encoding: approximate posterior distribution passed to
                ``kl_divergence_loss``.
        """
        kl_term = self.kl_divergence_loss(encoding).sum()
        reconstruction_term = self.reconstruction_loss(output, target)
        return kl_term + reconstruction_term