import torch import torch.nn as nn import torch.nn.functional as F def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor: v1_norm = torch.norm(v1, dim=1, keepdim=True) v2_norm = torch.norm(v2, dim=1, keepdim=True) v2T = torch.transpose(v2, 0, 1) inner_prod = torch.matmul(v1, v2T) v2_normT = torch.transpose(v2_norm, 0, 1) norm_mat = torch.matmul(v1_norm, v2_normT) loss_mat = torch.div(inner_prod, norm_mat) loss_mat = loss_mat * (1/tau) loss_mat = torch.exp(loss_mat) numerator = torch.diagonal(loss_mat) numerator = torch.unsqueeze(numerator, 0) Lv1_v2_denom = torch.sum(loss_mat, dim=1, keepdim=True) Lv1_v2_denom = torch.transpose(Lv1_v2_denom, 0, 1) #Lv1_v2_denom = Lv1_v2_denom - numerator Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True) #Lv2_v1_denom = Lv2_v1_denom - numerator Lv1_v2 = torch.div(numerator, Lv1_v2_denom) Lv1_v2 = -1 * torch.log(Lv1_v2) Lv1_v2 = torch.mean(Lv1_v2) Lv2_v1 = torch.div(numerator, Lv2_v1_denom) Lv2_v1 = -1 * torch.log(Lv2_v1) Lv2_v1 = torch.mean(Lv2_v1) return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom) def cand_spec_sim_loss(spec_enc, cand_enc): cand_enc = torch.transpose(cand_enc, 0, 1) # C x B x d spec_enc = spec_enc.unsqueeze(0) # 1 x B x d sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2) loss = torch.mean(sim) return loss class cons_spec_loss: def __init__(self, loss_type) -> None: self.loss_compute = {'cosine': self.cos_loss, 'l2':torch.nn.MSELoss()}[loss_type] def __call__(self,cons_spec, ind_spec): return self.loss_compute(cons_spec, ind_spec) def cos_loss(self, cons_spec, ind_spec): sim = nn.functional.cosine_similarity(cons_spec, ind_spec) loss = 1-torch.mean(sim) return loss class fp_loss: def __init__(self, loss_type) -> None: self.loss_compute = {'cosine': self.fp_loss_cos, 'bce': nn.BCELoss()}[loss_type] def __call__(self, predicted_fp, target_fp): return self.loss_compute(predicted_fp, target_fp) def fp_loss_cos(self, predicted_fp, target_fp): sim = nn.functional.cosine_similarity(predicted_fp, target_fp) return 1 - torch.mean(sim)