File size: 1,571 Bytes
d2a17a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ReConsLoss(nn.Module):
    """Optimal-sigma Gaussian reconstruction losses plus a KL term for a motion VAE.

    The reconstruction terms follow the "optimal sigma" VAE formulation
    (https://arxiv.org/pdf/2006.13202): the decoder's Gaussian standard
    deviation is not a learned parameter but is set to the RMS of the current
    residual. That sigma is not detached, so gradients also flow through it.

    Args:
        motion_dim: number of leading channels (last axis) that ``forward``
            scores; channels beyond it are ignored.
    """

    # Lower bound applied to the mean squared residual before ``sqrt().log()``.
    # Without it, an exact reconstruction (MSE == 0) yields log_sigma = -inf,
    # and the backward pass produces NaN gradients (0 * inf through softplus
    # and log). In float32 this floor is far below anything that survives
    # ``softclip(., -6)``, so forward loss values are unchanged.
    _MSE_FLOOR = 1e-30

    def __init__(self, motion_dim=272):
        super(ReConsLoss, self).__init__()
        self.motion_dim = motion_dim

    def softclip(self, tensor, min):
        """Softly bound ``tensor`` from below by ``min``.

        Computes ``min + softplus(tensor - min)``. The result never drops below
        ``min``, is smooth everywhere, and approaches ``tensor`` when ``tensor``
        is well above ``min``. (``min`` shadows the builtin; the parameter name
        is kept so keyword callers keep working.)
        """
        return min + F.softplus(tensor - min)

    def gaussian_nll(self, mu, log_sigma, x):
        """Element-wise negative log-likelihood of ``x`` under N(mu, exp(log_sigma) ** 2).

        Inputs are broadcast against each other; no reduction is applied.
        """
        return 0.5 * torch.pow((x - mu) / log_sigma.exp(), 2) + log_sigma + 0.5 * np.log(2 * np.pi)

    def _optimal_sigma_nll(self, pred, gt):
        """Summed Gaussian NLL of ``gt`` given ``pred``, with sigma set to the residual RMS."""
        # Average over axes 0-2 only, keeping dims for broadcasting. This gives
        # one shared sigma for 3-D inputs, or one per trailing index for
        # higher-rank inputs.
        mse = ((gt - pred) ** 2).mean([0, 1, 2], keepdim=True)
        log_sigma = mse.clamp_min(self._MSE_FLOOR).sqrt().log()
        # Softly floor log_sigma at -6 (sigma ~ 2.5e-3), so a near-perfect fit
        # cannot drive the ``+ log_sigma`` term of the NLL arbitrarily low.
        log_sigma = self.softclip(log_sigma, -6)
        return self.gaussian_nll(pred, log_sigma, gt).sum()

    def forward(self, motion_pred, motion_gt):
        """Optimal sigma VAE loss, see https://arxiv.org/pdf/2006.13202 for more details.

        Only the first ``self.motion_dim`` channels of the last axis are scored.
        Inputs are expected to have at least 3 dimensions (the sigma estimate
        averages over axes 0, 1 and 2).

        Returns:
            0-d tensor: the NLL summed over every scored element.
        """
        dim = self.motion_dim
        return self._optimal_sigma_nll(motion_pred[..., :dim], motion_gt[..., :dim])

    def forward_KL(self, mu, logvar):
        """KL divergence KL(N(mu, exp(logvar)) || N(0, 1)).

        The divergence is summed over axes 1 and 2 and averaged over the
        remaining axes. This assumes ``mu``/``logvar`` have at least 3 dimensions,
        e.g. (batch, tokens, latent) — TODO confirm against the encoder.
        """
        loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=(1, 2))
        return loss.mean()

    def forward_root(self, motion_pred, motion_gt, root_dim=8):
        """Optimal-sigma NLL restricted to the root-joint channels.

        ``[..., :root_dim]`` (default 8) relate to the root joint. These channels
        get their own sigma, estimated from their residual alone.

        Returns:
            0-d tensor: the NLL summed over every scored element.
        """
        return self._optimal_sigma_nll(motion_pred[..., :root_dim], motion_gt[..., :root_dim])