from pytorch_metric_learning.distances import CosineSimilarity
import torch


class InfoNCELoss():
    """InfoNCE-style contrastive loss with false-negative handling.

    Pairs ``(i, j)`` of rows of ``x`` and ``y`` are split by ``labels``:
    same-label pairs (and every diagonal pair ``(i, i)``) are positives, and
    different-label pairs are negatives. Pairs whose reference embeddings
    (``sts``) have a dot product ``>= threshold`` are treated as false
    negatives. They are always removed from the negatives and, when ``fna`` is
    true, also added to the positives.

    For each positive pair the loss term is
    ``-log(exp(s_p / T) / (exp(s_p / T) + sum_n exp(s_n / T)))``, where the
    sum runs over the kept negatives that share the positive's anchor row.
    The loss is the mean of these terms.

    Args:
        device: Stored on the instance; not used by any method in this class.
        k: Number of highest-similarity negative pairs to keep across the
            whole batch. ``-1`` (any value ``<= -1``) keeps every negative.
        temperature: Divisor applied to every similarity before the softmax.
        threshold: Dot-product threshold on ``sts`` at or above which a pair
            is flagged as a false negative.
        fna: If true, flagged false negatives are also used as positives.
    """

    def __init__(self, device, k, temperature=0.07, threshold=1.0, fna=False):
        super().__init__()
        self.device = device  # kept for API compatibility; unused below
        self.similarity = CosineSimilarity()
        self.k = k
        self.temperature = temperature
        self.threshold = threshold
        self.fna = fna

    def __call__(self, x, y, labels, sts):
        """Compute the loss between embeddings ``x`` and ``y``.

        Args:
            x: Embeddings whose rows act as anchors (rows of the similarity
                matrix).
            y: Embeddings compared against every row of ``x`` (columns of the
                similarity matrix).
            labels: Per-sample labels. The same labels index rows of ``x``
                and of ``y``, so the two presumably have equal batch size
                (TODO confirm with callers).
            sts: Reference embeddings used only to detect false negatives.

        Returns:
            ``(loss, num_positive_pairs)``. ``loss`` is a scalar tensor, or
            the int ``0`` when there are no positive or no negative pairs.
        """
        false_negatives = self.detect_false_negative(sts)
        a1, p, a2, n = self.get_all_pairs_indices(labels, false_negatives)
        mat = self.similarity(x, y)

        pos_pair, neg_pair = [], []
        if len(a1) > 0:
            pos_pair = mat[a1, p]
        if len(a2) > 0:
            neg_pair = mat[a2, n]

        if len(neg_pair) > 0 and self.k > -1:
            # Keep the k highest-similarity negatives over the whole batch,
            # sorted on-device. A *stable* descending sort keeps tied values
            # in their original (row-major) order, matching Python's stable
            # ``sorted(..., reverse=True)``. With k == 0 the selection is
            # simply empty, and _compute_loss returns 0 instead of the old
            # unpacking ValueError.
            _, order = torch.sort(
                neg_pair.detach(), descending=True, stable=True
            )
            order = order[: self.k]
            a2, n = a2[order], n[order]
            # Re-gather from ``mat`` so gradients still flow to the kept pairs.
            neg_pair = mat[a2, n]

        loss = self._compute_loss(pos_pair, neg_pair, (a1, p, a2, n))
        return loss, len(pos_pair)

    def detect_false_negative(self, embs):
        """Return ``(rows, cols)`` index tensors of pairs with high similarity.

        Uses the raw dot product ``embs @ embs.T``. It equals cosine
        similarity only if the rows of ``embs`` are L2-normalized (NOTE: not
        checked here; confirm the caller normalizes).

        Args:
            embs: Reference embeddings, presumably 2-D ``(batch, dim)``.

        Returns:
            The ``torch.where`` tuple of row and column indices where the dot
            product is ``>= self.threshold``.
        """
        mat = torch.matmul(embs, torch.t(embs))
        return torch.where(mat >= self.threshold)

    def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
        """Average the per-positive softmax loss.

        Args:
            pos_pairs: Similarities of positive pairs, aligned with ``a1``.
            neg_pairs: Similarities of negative pairs, aligned with ``a2``.
            indices_tuple: ``(a1, p, a2, n)``. Anchor/partner row indices for
                the positive and negative pairs.

        Returns:
            Scalar tensor loss, or the int ``0`` if either set is empty.
        """
        a1, p, a2, _ = indices_tuple
        if len(a1) > 0 and len(a2) > 0:
            dtype = neg_pairs.dtype
            # For a distance (not inverted), smaller means closer: negate so
            # that larger always means "more similar" below.
            if not self.similarity.is_inverted:
                pos_pairs = -pos_pairs
                neg_pairs = -neg_pairs

            pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
            neg_pairs = neg_pairs / self.temperature
            # (num_pos, num_neg) mask: negative j belongs to positive i iff
            # they share the same anchor row.
            n_per_p = a2.unsqueeze(0) == a1.unsqueeze(1)
            neg_pairs = neg_pairs * n_per_p
            # Masked-out entries get the most negative finite value so they
            # contribute exp(...) == 0 to the denominator.
            neg_pairs[n_per_p == 0] = torch.finfo(dtype).min

            # Subtract a per-row max before exp() for numerical stability.
            max_val = torch.max(
                pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]
            ).detach()
            numerator = torch.exp(pos_pairs - max_val).squeeze(1)
            denominator = (
                torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
            )
            # ``tiny`` guards against log(0).
            log_exp = torch.log(
                (numerator / denominator) + torch.finfo(dtype).tiny
            )
            return torch.mean(-log_exp)
        return 0

    def get_all_pairs_indices(self, labels, false_negatives):
        """Build positive and negative pair indices from labels.

        Args:
            labels: 1-D tensor of per-sample labels.
            false_negatives: ``(rows, cols)`` index tensors from
                :meth:`detect_false_negative`.

        Returns:
            ``(a1_idx, p_idx, a2_idx, n_idx)``. Row/column indices of the
            positive pairs (same label, the diagonal, and, if ``fna``, the
            flagged false negatives), then of the negative pairs (different
            label, excluding flagged false negatives and the diagonal).
        """
        labels1 = labels.unsqueeze(1)
        labels2 = labels.unsqueeze(0)
        matches = (labels1 == labels2).byte()
        diffs = matches ^ 1
        # FNE: flagged false negatives are never used as negatives.
        diffs[false_negatives[0], false_negatives[1]] = 0
        if self.fna:
            # FNA: flagged false negatives are additionally used as positives.
            matches[false_negatives[0], false_negatives[1]] = 1
        diffs.fill_diagonal_(0)
        matches.fill_diagonal_(1)
        a1_idx, p_idx = torch.where(matches)
        a2_idx, n_idx = torch.where(diffs)
        return a1_idx, p_idx, a2_idx, n_idx