from __future__ import print_function
import torch
import torch.nn as nn
# Wildcard import kept (and kept in its original position, before `F`): other
# modules may rely on names re-exported through this file.
from .model import*
import torch.nn.functional as F


def vq_cos_sim(embedding, input_tensor, use_dynamic_p=False, ddp=False):
    """Return, for each input vector, the index of the most cosine-similar codebook row.

    Inputs are L2-normalised along dim 2 and codebook rows along dim 1, so
    the matmul below yields cosine similarities; the arg-max over the code
    axis picks the nearest code.

    Args:
        embedding: module exposing ``.weight`` whose rows are code vectors,
            i.e. shape (num_codes, dim) -- presumably an ``nn.Embedding``.
            When ``ddp`` is True this is the wrapper and the real module is
            read from ``embedding.module``.
        input_tensor: expected 3-D, e.g. (batch, seq, dim) -- TODO confirm
            with callers. Normalisation and the max are both taken along
            dim 2, which only coincides with the feature/code axis for
            3-D input.
        use_dynamic_p: if True, also return the winning similarity values.
        ddp: if True, unwrap ``embedding.module`` before reading the weight.

    Returns:
        ``indices`` or, when ``use_dynamic_p`` is True,
        ``(indices, cos_sim_values)``. Both are ``.squeeze()``d, which drops
        EVERY size-1 dimension -- a batch or sequence length of 1 therefore
        changes the rank of the result.
    """
    # A DDP-wrapped module keeps the real module under `.module`.
    if ddp:
        embedding_weight = embedding.module.weight
    else:
        embedding_weight = embedding.weight
    input_norm = F.normalize(input_tensor, p=2, dim=2)
    embedding_norm = F.normalize(embedding_weight, p=2, dim=1)
    # (..., dim) @ (dim, num_codes) -> cosine similarity to every code.
    similarity = torch.matmul(input_norm, embedding_norm.t())
    cos_sim_values, indices = similarity.max(dim=2)
    if use_dynamic_p:
        return indices.squeeze(), cos_sim_values.squeeze()
    return indices.squeeze()


class RatioLossWithMSELoss(nn.Module):
    """Weighted relative-error ("ratio") loss plus an unweighted MSE term.

    loss = w(it) * mean(|target - output| / (|target| + eps)) + MSE(output, target)

    ``w`` ramps linearly from ``min_weight`` (at iteration 0) to
    ``max_weight`` (at ``total_iters``) and is held at ``max_weight``
    afterwards. The latest un-weighted components are stored in
    ``self.losses`` under the keys ``'ratio'`` and ``'mse'``.

    Args:
        total_iters: iteration count over which the weight ramps.
        min_weight: ratio-term weight at iteration 0.
        max_weight: ratio-term weight from ``total_iters`` onwards.
        eps: stabiliser added to ``|target|`` in the denominator. The default
            is a bfloat16 tensor (so not exactly 1e-3), created once at
            import and shared by all instances that use the default. It is
            stored as a plain attribute, not a registered buffer, so
            ``.to(device)`` does not move it.
    """

    def __init__(self, total_iters, min_weight=0.001, max_weight=1,
                 eps=torch.tensor(1e-3, dtype=torch.bfloat16)):
        super(RatioLossWithMSELoss, self).__init__()
        self.eps = eps
        self.total_iters = total_iters
        self.min_weight = min_weight
        self.max_weight = max_weight
        self.mse = nn.MSELoss()
        # Latest un-weighted component losses, refreshed on every forward().
        self.losses = {}

    def forward(self, output, target, current_iter):
        """Return the combined loss for ``output`` vs ``target`` at ``current_iter``.

        ``output`` and ``target`` are presumably the same shape -- TODO
        confirm; ``nn.MSELoss`` would broadcast (with a warning) otherwise.
        """
        # Clamp progress to [0, 1] so the weight stays within
        # [min_weight, max_weight] even when training runs past total_iters.
        progress = min(max(current_iter / self.total_iters, 0.0), 1.0)
        weight = self.min_weight + (self.max_weight - self.min_weight) * progress
        # Element-wise relative error; eps guards against target == 0.
        ratio = torch.abs(target - output) / (torch.abs(target) + self.eps)
        # Computed once and reused for both logging and the return value.
        mse = self.mse(output, target)
        self.losses['ratio'] = ratio.mean()
        self.losses['mse'] = mse
        return (weight * ratio).mean() + mse


if __name__ == '__main__':
    pass