# CalliReader / models / similarity.py
from __future__ import print_function
import torch
import torch.nn as nn
from .model import*
import torch.nn.functional as F
def vq_cos_sim(embedding, input_tensor, use_dynamic_p=False,ddp=False):
    """Find the nearest codebook entry for each input vector by cosine similarity.

    Args:
        embedding: an ``nn.Embedding`` codebook (possibly DDP-wrapped).
        input_tensor: tensor of shape (batch, seq, dim) — dim must match the
            embedding dimension. (assumed 3-D since normalization uses dim=2)
        use_dynamic_p: when True, also return the max cosine similarities.
        ddp: when True, unwrap ``embedding.module`` to reach the weight.

    Returns:
        Squeezed index tensor of nearest codes, and — if ``use_dynamic_p`` —
        the squeezed cosine-similarity values as a second element.
    """
    weight = embedding.module.weight if ddp else embedding.weight
    # Unit-normalize both sides so the dot product equals cosine similarity.
    queries = F.normalize(input_tensor, p=2, dim=2)
    codes = F.normalize(weight, p=2, dim=1)
    sim = queries @ codes.t()
    best_sim, best_idx = sim.max(dim=2)
    if use_dynamic_p:
        return best_idx.squeeze(), best_sim.squeeze()
    return best_idx.squeeze()
class RatioLossWithMSELoss(nn.Module):
    """Relative-error ("ratio") loss combined with MSE.

    The ratio term ``|target - output| / (|target| + eps)`` is scaled by a
    weight that ramps linearly from ``min_weight`` to ``max_weight`` over
    ``total_iters`` iterations; the unweighted MSE is always added.
    The most recent unweighted components are kept in ``self.losses``
    under the keys ``'ratio'`` and ``'mse'`` for logging.
    """

    def __init__(self, total_iters, min_weight=0.001, max_weight=1, eps=None):
        """
        Args:
            total_iters: number of iterations over which the ratio weight ramps.
            min_weight: ratio-loss weight at iteration 0.
            max_weight: ratio-loss weight at iteration ``total_iters``.
            eps: denominator stabilizer; defaults to a bfloat16 tensor of 1e-3
                (matching the historical default) built per instance rather
                than as a shared mutable default argument.
        """
        super(RatioLossWithMSELoss, self).__init__()
        # Avoid a mutable (tensor) default argument shared across instances.
        if eps is None:
            eps = torch.tensor(1e-3, dtype=torch.bfloat16)
        self.eps = eps
        self.total_iters = total_iters
        self.min_weight = min_weight
        self.max_weight = max_weight
        self.mse = nn.MSELoss()
        self.losses = {}

    def forward(self, output, target, current_iter):
        """Return weighted ratio loss + MSE for the given iteration.

        NOTE(review): ``weight`` is a linear ramp and is NOT clamped, so for
        ``current_iter > total_iters`` it exceeds ``max_weight`` — confirm
        whether callers ever pass iterations past the schedule end.
        """
        weight = self.min_weight + (self.max_weight - self.min_weight) * (current_iter / self.total_iters)
        loss = (torch.abs(target - output)) / (torch.abs(target) + self.eps)
        weighted_loss = weight * loss
        # Compute MSE once and reuse it (was computed twice per forward).
        mse_loss = self.mse(output, target)
        self.losses['ratio'] = loss.mean()
        self.losses['mse'] = mse_loss
        return weighted_loss.mean() + mse_loss
# No script behavior: this module is import-only.
if __name__ == '__main__':
    pass