|
|
from pytorch_metric_learning.distances import CosineSimilarity |
|
|
import torch |
|
|
|
|
|
|
|
|
class InfoNCELoss:
    """InfoNCE contrastive loss with false-negative elimination/attraction.

    Pairs ``(i, j)`` index the similarity matrix between two embedding batches.
    Pairs with equal ``labels`` are positives and pairs with different labels are
    negatives. A pair whose dot product in the auxiliary ``sts`` embedding space
    is ``>= threshold`` is treated as a false negative: it is always removed from
    the negatives and, when ``fna`` is True, additionally counted as a positive.
    When ``k`` is non-negative, only the ``k`` most similar negatives of the whole
    batch are kept (hard-negative selection).
    """

    def __init__(self, device, k, temperature=0.07, threshold=1.0, fna=False):
        """Configure the loss.

        Args:
            device: Stored as ``self.device``; not read by any method of this class.
            k: Number of hardest negatives to keep across the batch. Any negative
                value (e.g. ``-1``) keeps all negatives; ``0`` keeps none.
            temperature: Divisor applied to all similarities before the softmax.
            threshold: Dot-product value in ``sts`` space at or above which a pair
                is treated as a false negative.
            fna: If True, detected false negatives are also used as positives.
        """
        super().__init__()
        self.device = device
        self.similarity = CosineSimilarity()
        self.k = k
        self.temperature = temperature
        self.threshold = threshold
        self.fna = fna

    def __call__(self, x, y, labels, sts):
        """Compute the loss for one batch.

        Args:
            x, y: Embedding batches compared with ``self.similarity``; row ``i`` of
                ``x`` and row ``j`` of ``y`` form the pair ``(i, j)``. Presumably
                both have ``len(labels)`` rows -- TODO confirm with callers.
            labels: Per-row labels; equal labels mark positive pairs.
            sts: Embeddings used only to detect false negatives via their
                pairwise dot products.

        Returns:
            ``(loss, num_positive_pairs)``. ``loss`` is a 0-dim tensor, or the int
            ``0`` when there are no positive or no negative pairs.
        """
        false_negatives = self.detect_false_negative(sts)
        a1, p, a2, n = self.get_all_pairs_indices(labels, false_negatives)

        mat = self.similarity(x, y)

        pos_pair = mat[a1, p] if len(a1) > 0 else []
        neg_pair = mat[a2, n] if len(a2) > 0 else []

        if len(neg_pair) > 0 and self.k > -1:
            # Keep the k negatives with the highest similarity across the batch.
            # A stable descending sort breaks ties exactly like Python's
            # sorted(..., reverse=True) but stays on the tensor's device; the
            # slice also handles k == 0 (empty selection) and k > #negatives.
            order = torch.sort(neg_pair.detach(), descending=True, stable=True).indices
            order = order[: self.k]
            a2, n = a2[order], n[order]
            neg_pair = mat[a2, n]

        return self._compute_loss(pos_pair, neg_pair, (a1, p, a2, n)), len(pos_pair)

    def detect_false_negative(self, embs):
        """Return ``(rows, cols)`` of pairs whose dot product is ``>= threshold``.

        NOTE(review): this is the raw dot product, so ``threshold`` acts as a
        cosine threshold only if the rows of ``embs`` are L2-normalised -- confirm
        with the caller.
        """
        # Only indices are returned, so there is no need to track gradients.
        with torch.no_grad():
            mat = torch.matmul(embs, torch.t(embs))
        return torch.where(mat >= self.threshold)

    def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
        """Return the mean InfoNCE loss over all positive pairs.

        For the i-th positive pair with anchor ``a1[i]`` the loss is
        ``-log(exp(s_pos / T) / (exp(s_pos / T) + sum exp(s_neg / T)))``, where the
        sum runs only over negatives whose anchor ``a2`` equals ``a1[i]``.

        Returns the int ``0`` when there are no positive or no negative pairs.
        """
        a1, p, a2, _ = indices_tuple

        if len(a1) > 0 and len(a2) > 0:
            dtype = neg_pairs.dtype

            # If self.similarity is a distance (is_inverted is False), negate the
            # values so that larger always means "closer".
            if not self.similarity.is_inverted:
                pos_pairs = -pos_pairs
                neg_pairs = -neg_pairs

            pos_pairs = pos_pairs.unsqueeze(1) / self.temperature  # (P, 1)
            neg_pairs = neg_pairs / self.temperature  # (N,)
            # n_per_p[i, j] is True when negative j shares the anchor of positive i.
            n_per_p = a2.unsqueeze(0) == a1.unsqueeze(1)  # (P, N)
            neg_pairs = neg_pairs * n_per_p
            # Mask negatives of other anchors so their exp() underflows to 0.
            neg_pairs[n_per_p == 0] = torch.finfo(dtype).min

            # Subtract the per-row maximum (log-sum-exp trick) for stability.
            max_val = torch.max(
                pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]
            ).detach()
            numerator = torch.exp(pos_pairs - max_val).squeeze(1)
            denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
            # `tiny` guards against log(0) if the ratio underflows.
            log_exp = torch.log((numerator / denominator) + torch.finfo(dtype).tiny)
            return torch.mean(-log_exp)

        # NOTE(review): a plain int, not a tensor; callers that call
        # .backward()/.item() on the result must special-case this.
        return 0

    def get_all_pairs_indices(self, labels, false_negatives):
        """Build index tensors for all positive and negative pairs.

        Args:
            labels: Tensor of per-row labels (presumably 1-D); pairs ``(i, j)``
                with equal labels are positives, the rest are negatives.
            false_negatives: ``(rows, cols)`` index tensors (as returned by
                ``detect_false_negative``) of pairs never used as negatives.

        Returns:
            ``(a1, p, a2, n)``: anchor/positive and anchor/negative index tensors.
        """
        matches = (labels.unsqueeze(1) == labels.unsqueeze(0)).byte()
        diffs = matches ^ 1

        # False-negative elimination: never contrast against these pairs ...
        diffs[false_negatives[0], false_negatives[1]] = 0
        # ... and, optionally, false-negative attraction: treat them as positives.
        if self.fna:
            matches[false_negatives[0], false_negatives[1]] = 1

        # (i, i) is always a positive pair and never a negative one.
        diffs.fill_diagonal_(0)
        matches.fill_diagonal_(1)

        a1_idx, p_idx = torch.where(matches)
        a2_idx, n_idx = torch.where(diffs)
        return a1_idx, p_idx, a2_idx, n_idx
|
|
|