|
|
|
|
| from __future__ import print_function |
| import torch |
| import torch.nn as nn |
| from .model import* |
| import torch.nn.functional as F |
|
|
def vq_cos_sim(embedding, input_tensor, use_dynamic_p=False, ddp=False):
    """Select, for every input vector, the codebook row with the highest cosine similarity.

    Both the input (along dim 2) and the codebook rows (along dim 1) are
    L2-normalised, so their matrix product holds cosine similarities.

    Args:
        embedding: Object exposing the codebook as ``.weight`` (or as
            ``.module.weight`` when ``ddp`` is true).
        input_tensor: Tensor with the feature axis on dim 2
            (presumably ``(batch, seq, dim)`` -- confirm against callers).
        use_dynamic_p: When true, also return the winning similarity values.
        ddp: When true, read the codebook through ``embedding.module``.

    Returns:
        The argmax indices with all size-1 dimensions squeezed out, or a
        ``(indices, similarities)`` tuple (both squeezed) when
        ``use_dynamic_p`` is true.
    """
    # A wrapped module keeps the real embedding under ``.module``.
    codebook = (embedding.module if ddp else embedding).weight

    unit_inputs = F.normalize(input_tensor, p=2, dim=2)
    unit_codes = F.normalize(codebook, p=2, dim=1)

    # Dot products between unit vectors are cosine similarities.
    best_sim, best_idx = torch.max(unit_inputs @ unit_codes.t(), dim=2)

    # NOTE: squeeze() drops *every* size-1 dimension, including a batch of 1.
    best_idx = best_idx.squeeze()
    if not use_dynamic_p:
        return best_idx
    return best_idx, best_sim.squeeze()
|
|
|
|
class RatioLossWithMSELoss(nn.Module):
    """Mean-squared error plus a relative-error term whose weight ramps up linearly.

    The relative ("ratio") term is ``|target - output| / (|target| + eps)``.
    Its weight goes linearly from ``min_weight`` to ``max_weight`` as
    ``current_iter`` advances from 0 to ``total_iters``.

    After each forward pass ``self.losses`` holds the unweighted components
    under the keys ``'ratio'`` and ``'mse'``. The stored values are the
    tensors as computed (they are not detached).
    """

    def __init__(self, total_iters, min_weight=0.001, max_weight=1,eps=torch.tensor(1e-3, dtype=torch.bfloat16)):
        """Store the schedule parameters.

        Args:
            total_iters: Iteration count over which the ratio weight ramps.
            min_weight: Ratio-term weight at iteration 0.
            max_weight: Ratio-term weight at ``total_iters`` and beyond.
            eps: Added to ``|target|`` to keep the denominator away from zero.
        """
        super(RatioLossWithMSELoss, self).__init__()
        self.eps = eps
        self.total_iters = total_iters
        self.min_weight = min_weight
        self.max_weight = max_weight
        self.mse = nn.MSELoss()
        self.losses = {}

    def forward(self, output, target, current_iter):
        """Return ``weight * mean(ratio) + mse`` for the given iteration.

        Args:
            output: Predicted tensor.
            target: Reference tensor, broadcast-compatible with ``output``.
            current_iter: Current training iteration, used for the ramp.

        Returns:
            Scalar tensor: weighted mean ratio loss plus the MSE loss.
        """
        # Clamp progress so the weight stays inside [min_weight, max_weight]
        # even if training runs past ``total_iters``.
        progress = min(max(current_iter / self.total_iters, 0.0), 1.0)
        weight = self.min_weight + (self.max_weight - self.min_weight) * progress

        ratio = torch.abs(target - output) / (torch.abs(target) + self.eps)

        # Evaluate the MSE once and reuse it for both the log and the total.
        mse_loss = self.mse(output, target)

        self.losses['ratio'] = ratio.mean()
        self.losses['mse'] = mse_loss
        return (weight * ratio).mean() + mse_loss
| |
|
|
|
|
|
|
if __name__ == '__main__':
    # Nothing to run: this module only provides definitions for import.
    pass