| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
class ReConsLoss(nn.Module):
    """Reconstruction losses for motion-VAE training.

    Implements the "optimal sigma" Gaussian negative log-likelihood
    reconstruction loss (sigma-VAE, https://arxiv.org/pdf/2006.13202)
    plus the standard KL-divergence term.
    """

    def __init__(self, motion_dim=272, root_dim=8):
        """
        Args:
            motion_dim: number of trailing feature channels of the motion
                representation used by ``forward``.
            root_dim: number of leading channels describing the root joint,
                used by ``forward_root`` (previously hard-coded to 8).
        """
        super().__init__()
        self.motion_dim = motion_dim
        self.root_dim = root_dim

    def softclip(self, tensor, min):
        """Differentiable lower clamp: smoothly pushes values below *min*
        up towards *min* while leaving values far above *min* unchanged."""
        # NOTE(review): parameter name shadows the builtin ``min``; kept
        # unchanged so keyword callers (``softclip(x, min=...)``) still work.
        return min + F.softplus(tensor - min)

    def gaussian_nll(self, mu, log_sigma, x):
        """Element-wise negative log-likelihood of ``x`` under
        N(mu, exp(log_sigma)^2); constant 0.5*log(2*pi) included."""
        return (
            0.5 * torch.pow((x - mu) / log_sigma.exp(), 2)
            + log_sigma
            + 0.5 * np.log(2 * np.pi)
        )

    def _optimal_sigma_nll(self, pred, gt):
        """Summed Gaussian NLL with the analytically optimal sigma.

        Sigma is the RMSE between ``pred`` and ``gt`` (the MLE of a shared
        scalar sigma), soft-clipped below at exp(-6) for numerical
        stability.  The mean over dims [0, 1, 2] assumes the inputs have
        at least 3 leading dimensions — TODO confirm against callers.
        """
        log_sigma = ((gt - pred) ** 2).mean([0, 1, 2], keepdim=True).sqrt().log()
        log_sigma = self.softclip(log_sigma, -6)
        return self.gaussian_nll(pred, log_sigma, gt).sum()

    def forward(self, motion_pred, motion_gt):
        """Optimal sigma VAE loss, see https://arxiv.org/pdf/2006.13202 for more details"""
        return self._optimal_sigma_nll(
            motion_pred[..., :self.motion_dim],
            motion_gt[..., :self.motion_dim],
        )

    def forward_KL(self, mu, logvar):
        """KL(q(z|x) || N(0, I)) summed over dims (1, 2), averaged over
        the batch dimension."""
        loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=(1, 2))
        return loss.mean()

    def forward_root(self, motion_pred, motion_gt):
        """Reconstruction loss restricted to the root-joint channels
        ([..., :root_dim], default 8)."""
        return self._optimal_sigma_nll(
            motion_pred[..., :self.root_dim],
            motion_gt[..., :self.root_dim],
        )