File size: 2,730 Bytes
d9df210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
import torch.nn.functional as F

def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
        v1_norm = torch.norm(v1, dim=1, keepdim=True)
        v2_norm = torch.norm(v2, dim=1, keepdim=True)
        
        v2T = torch.transpose(v2, 0, 1)
        
        inner_prod = torch.matmul(v1, v2T)
        
        v2_normT = torch.transpose(v2_norm, 0, 1)
        
        norm_mat = torch.matmul(v1_norm, v2_normT)
        
        loss_mat = torch.div(inner_prod, norm_mat)
        
        loss_mat = loss_mat * (1/tau)
        
        loss_mat = torch.exp(loss_mat)
        
        numerator = torch.diagonal(loss_mat)
        numerator = torch.unsqueeze(numerator, 0)
        
        Lv1_v2_denom = torch.sum(loss_mat, dim=1, keepdim=True)
        Lv1_v2_denom = torch.transpose(Lv1_v2_denom, 0, 1)
        #Lv1_v2_denom = Lv1_v2_denom - numerator
        
        Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True)
        #Lv2_v1_denom = Lv2_v1_denom - numerator
        
        Lv1_v2 = torch.div(numerator, Lv1_v2_denom)
        
        Lv1_v2 = -1 * torch.log(Lv1_v2)
        Lv1_v2 = torch.mean(Lv1_v2)
        
        Lv2_v1 = torch.div(numerator, Lv2_v1_denom)
        
        Lv2_v1 = -1 * torch.log(Lv2_v1)
        Lv2_v1 = torch.mean(Lv2_v1)
        
        return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)

def cand_spec_sim_loss(spec_enc, cand_enc):
        cand_enc = torch.transpose(cand_enc, 0, 1) # C x B x d
        spec_enc = spec_enc.unsqueeze(0) # 1 x B x d

        sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)
        loss = torch.mean(sim)

        return loss

class cons_spec_loss:
        def __init__(self, loss_type) -> None:
                self.loss_compute = {'cosine': self.cos_loss,
                                     'l2':torch.nn.MSELoss()}[loss_type]
        def __call__(self,cons_spec, ind_spec):
                return self.loss_compute(cons_spec, ind_spec)
        
        def cos_loss(self, cons_spec, ind_spec):
                sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
                loss = 1-torch.mean(sim) 
                return loss

class fp_loss:
        def __init__(self, loss_type) -> None:
                self.loss_compute = {'cosine': self.fp_loss_cos,
                                        'bce': nn.BCELoss()}[loss_type]
        
        def __call__(self, predicted_fp, target_fp):
                return self.loss_compute(predicted_fp, target_fp)
        
        def fp_loss_cos(self, predicted_fp, target_fp):
                sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
                return 1 - torch.mean(sim)