File size: 3,042 Bytes
f6ffda2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from pytorch_metric_learning.distances import CosineSimilarity
import torch


class InfoNCELoss:
    """InfoNCE-style contrastive loss over a cross-batch similarity matrix.

    Positive and negative pairs are derived from ``labels``. Pairs whose
    ``sts`` dot product reaches ``threshold`` are treated as false negatives:
    they are always removed from the negative set and, when ``fna`` is true,
    additionally added to the positive set.

    Args:
        device: Stored on the instance; not used by the methods of this class.
        k: Number of highest-similarity negative pairs to keep (selected over
            the whole batch, not per anchor). Any value ``<= -1`` keeps all
            negatives; ``0`` keeps none.
        temperature: Divisor applied to the similarities before the softmax.
        threshold: Minimum ``sts`` dot product for a pair to be flagged as a
            false negative.
        fna: When true, flagged false negatives are also used as positives.
    """

    def __init__(self, device, k, temperature=0.07, threshold=1.0, fna=False):
        super().__init__()
        self.device = device
        self.similarity = CosineSimilarity()
        self.k = k
        self.temperature = temperature
        self.threshold = threshold
        self.fna = fna

    def __call__(self, x, y, labels, sts):
        """Compute the loss for one batch.

        Args:
            x: Query embeddings; passed as first argument to the similarity.
            y: Reference embeddings; passed as second argument.
            labels: 1-D tensor of labels, one per row of the batch.
            sts: Embeddings used only for false-negative detection
                (assumed to be one row per sample -- TODO confirm).

        Returns:
            A ``(loss, num_positive_pairs)`` tuple. ``loss`` is a scalar
            tensor, or the int ``0`` when no negative pair is available.
        """
        false_negatives = self.detect_false_negative(sts)
        a1, p, a2, n = self.get_all_pairs_indices(labels, false_negatives)

        mat = self.similarity(x, y)
        pos_pair = mat[a1, p] if len(a1) > 0 else []
        neg_pair = mat[a2, n] if len(a2) > 0 else []

        if len(neg_pair) > 0 and self.k > -1:
            # Keep only the k most similar (hardest) negatives. A stable sort
            # keeps tied pairs in their original order, and slicing (rather
            # than unpacking) makes k == 0 yield an empty selection instead
            # of raising.
            _, order = torch.sort(neg_pair.detach(), descending=True, stable=True)
            keep = order[: self.k]
            a2, n = a2[keep], n[keep]
            neg_pair = mat[a2, n]

        loss = self._compute_loss(pos_pair, neg_pair, (a1, p, a2, n))
        return loss, len(pos_pair)

    def detect_false_negative(self, embs):
        """Return (row, col) indices where ``embs @ embs.T >= threshold``."""
        mat = torch.matmul(embs, torch.t(embs))
        return torch.where(mat >= self.threshold)

    def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
        """Average ``-log softmax`` of each positive against its negatives.

        Args:
            pos_pairs: 1-D tensor of positive-pair similarities.
            neg_pairs: 1-D tensor of negative-pair similarities.
            indices_tuple: ``(a1, p, a2, n)`` anchor/positive and
                anchor/negative index tensors matching the two pair tensors.

        Returns:
            Scalar tensor with the mean loss, or the int ``0`` when there are
            no positive or no negative pairs.
        """
        a1, _, a2, _ = indices_tuple

        if len(a1) == 0 or len(a2) == 0:
            # Kept as a plain int for backward compatibility with callers.
            return 0

        dtype = neg_pairs.dtype

        if not self.similarity.is_inverted:
            # Distance-like measure: negate so that larger means closer.
            pos_pairs = -pos_pairs
            neg_pairs = -neg_pairs

        pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
        neg_pairs = neg_pairs / self.temperature

        # n_per_p[i, j] is True when negative pair j shares its anchor with
        # positive pair i; other entries get the most negative finite value
        # so that they contribute exp(...) == 0 to the denominator.
        n_per_p = a2.unsqueeze(0) == a1.unsqueeze(1)
        fill = neg_pairs.new_full((), torch.finfo(dtype).min)
        neg_pairs = torch.where(n_per_p, neg_pairs, fill)

        # Subtract the row-wise maximum before exponentiating for stability.
        max_val = torch.max(
            pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]
        ).detach()
        numerator = torch.exp(pos_pairs - max_val).squeeze(1)
        denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
        log_exp = torch.log((numerator / denominator) + torch.finfo(dtype).tiny)
        return torch.mean(-log_exp)

    def get_all_pairs_indices(self, labels, false_negatives):
        """Build positive and negative pair indices from labels.

        Args:
            labels: 1-D tensor of labels.
            false_negatives: ``(rows, cols)`` index tensors of pairs to drop
                from the negatives (and add to the positives when
                ``self.fna`` is true).

        Returns:
            ``(a1_idx, p_idx, a2_idx, n_idx)``: anchor/positive indices and
            anchor/negative indices. The diagonal is always a positive pair
            and never a negative pair.
        """
        labels1 = labels.unsqueeze(1)
        labels2 = labels.unsqueeze(0)
        matches = (labels1 == labels2).byte()
        diffs = matches ^ 1

        fn_rows, fn_cols = false_negatives[0], false_negatives[1]
        # FNE in the original code: presumably "false negative elimination".
        diffs[fn_rows, fn_cols] = 0
        if self.fna:
            # FNA in the original code: presumably "false negative attraction".
            matches[fn_rows, fn_cols] = 1

        diffs.fill_diagonal_(0)
        matches.fill_diagonal_(1)

        a1_idx, p_idx = torch.where(matches)
        a2_idx, n_idx = torch.where(diffs)
        return a1_idx, p_idx, a2_idx, n_idx