# NOTE(review): removed two lines of file-viewer metadata (file size, commit hash
# and a line-number gutter) that were accidentally pasted in and broke parsing.
from __future__ import print_function
import torch
import torch.nn as nn
from .model import*
import torch.nn.functional as F
def vq_cos_sim(embedding, input_tensor, use_dynamic_p=False, ddp=False):
    """Nearest-codebook lookup by cosine similarity.

    Normalizes both the input vectors (last dim) and the embedding rows to
    unit length, so the matmul below yields cosine similarities, then picks
    the best-matching embedding index for each input position.

    Args:
        embedding: an ``nn.Embedding``-like module (or its DDP wrapper).
        input_tensor: query tensor; assumes shape (batch, seq, dim) given
            the ``dim=2`` normalization/max — TODO confirm against callers.
        use_dynamic_p: when True, also return the max cosine similarities.
        ddp: when True, unwrap ``embedding.module`` to reach the weight.

    Returns:
        Squeezed index tensor, or ``(indices, cos_sim_values)`` squeezed
        when ``use_dynamic_p`` is set.
    """
    # Reach through the DistributedDataParallel wrapper when requested.
    codebook = (embedding.module if ddp else embedding).weight
    queries = F.normalize(input_tensor, p=2, dim=2)
    codes = F.normalize(codebook, p=2, dim=1)
    # (batch, seq, dim) @ (dim, num_codes) -> per-position similarity scores.
    scores = torch.matmul(queries, codes.t())
    best_scores, best_indices = scores.max(dim=2)
    if use_dynamic_p:
        return best_indices.squeeze(), best_scores.squeeze()
    return best_indices.squeeze()
class RatioLossWithMSELoss(nn.Module):
    """Relative-error ("ratio") loss with an iteration-annealed weight, plus MSE.

    The total loss is ``weight * mean(|target - output| / (|target| + eps)) +
    MSE(output, target)``, where ``weight`` ramps linearly from ``min_weight``
    to ``max_weight`` over ``total_iters`` iterations. The unweighted ratio
    and MSE components are stashed in ``self.losses`` for logging.
    """

    def __init__(self, total_iters, min_weight=0.001, max_weight=1,
                 eps=torch.tensor(1e-3, dtype=torch.bfloat16)):
        """
        Args:
            total_iters: number of iterations over which the weight anneals.
            min_weight: ratio-loss weight at iteration 0.
            max_weight: ratio-loss weight at iteration ``total_iters``.
            eps: denominator stabilizer guarding against |target| == 0.
                NOTE: the default tensor is created once at class-definition
                time and shared across instances; it is never mutated here so
                this is safe, but pass a value explicitly to control
                dtype/precision (the bfloat16 default is only ~8-bit accurate).
        """
        super(RatioLossWithMSELoss, self).__init__()
        self.eps = eps
        self.total_iters = total_iters
        self.min_weight = min_weight
        self.max_weight = max_weight
        self.mse = nn.MSELoss()
        # Component losses from the most recent forward(), for logging.
        self.losses = {}

    def forward(self, output, target, current_iter):
        """Return the weighted ratio loss plus MSE for this iteration."""
        # Linear anneal from min_weight to max_weight.
        # NOTE(review): weight keeps growing past max_weight when
        # current_iter > total_iters — confirm that is intended.
        weight = self.min_weight + (self.max_weight - self.min_weight) * (current_iter / self.total_iters)
        loss = (torch.abs(target - output)) / (torch.abs(target) + self.eps)
        weighted_loss = weight * loss
        # Compute MSE once and reuse it (was computed twice per forward).
        mse = self.mse(output, target)
        self.losses['ratio'] = loss.mean()
        self.losses['mse'] = mse
        return weighted_loss.mean() + mse
# Module is import-only; no script behavior. (Also removes the stray trailing
# "|" gutter artifact that made the original line a syntax error.)
if __name__ == '__main__':
    pass