# motion-stream / utils / losses.py
# Uploaded by: zirobtc
# Initial upload of MotionStreamer code, excluding large extracted data and output folders.
# Commit: 0e267a7 (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class ReConsLoss(nn.Module):
    """Reconstruction and KL losses for a motion VAE.

    Reconstruction terms use the "optimal sigma" VAE formulation
    (https://arxiv.org/pdf/2006.13202): the Gaussian noise scale is not
    learned but set analytically to the RMS reconstruction error of the
    current batch, and the loss is the summed Gaussian negative
    log-likelihood under that sigma.

    Args:
        motion_dim: Number of leading feature channels (last axis) scored by
            ``forward``.
        root_dim: Number of leading feature channels scored by
            ``forward_root``. Defaults to 8, the slice this module attributes
            to the root joint.
    """

    # Floor applied to the batch MSE before taking its log. With an exact
    # reconstruction (MSE == 0), log(sqrt(0)) = -inf makes the backward of
    # ``log`` compute 0/0, producing NaN gradients. In float32 the floor does
    # not change the forward value: any MSE this small is already
    # soft-clipped to exactly ``_LOG_SIGMA_MIN``.
    _MSE_FLOOR = 1e-24

    # Soft lower bound for log(sigma), applied via ``softclip``.
    _LOG_SIGMA_MIN = -6

    def __init__(self, motion_dim=272, root_dim=8):
        super().__init__()
        self.motion_dim = motion_dim
        self.root_dim = root_dim

    def softclip(self, tensor, min):
        """Smoothly bound ``tensor`` from below by ``min``.

        Returns ``min + softplus(tensor - min)``: always greater than ``min``
        and close to ``tensor`` when ``tensor`` is well above ``min``.
        The parameter name shadows the builtin ``min``; it is kept so that
        keyword callers keep working.
        """
        return min + F.softplus(tensor - min)

    def gaussian_nll(self, mu, log_sigma, x):
        """Element-wise negative log-likelihood of ``x`` under N(mu, exp(log_sigma)**2).

        No reduction is applied; arguments broadcast under the usual torch
        rules.
        """
        return (
            0.5 * torch.pow((x - mu) / log_sigma.exp(), 2)
            + log_sigma
            + 0.5 * np.log(2 * np.pi)
        )

    def _optimal_sigma_nll(self, pred, gt):
        """Summed Gaussian NLL of ``gt`` given ``pred``, with sigma set to the batch RMS error.

        The squared error is averaged over dims 0-2 (so inputs need at least
        3 dims); for 3-D inputs this yields one shared sigma. Sigma is not
        detached, so gradients also flow through it.
        """
        mse = ((gt - pred) ** 2).mean([0, 1, 2], keepdim=True)
        # Floor before the log to avoid NaN gradients on exact reconstruction.
        log_sigma = mse.clamp_min(self._MSE_FLOOR).sqrt().log()
        log_sigma = self.softclip(log_sigma, self._LOG_SIGMA_MIN)
        return self.gaussian_nll(pred, log_sigma, gt).sum()

    def forward(self, motion_pred, motion_gt):
        """Optimal sigma VAE loss, see https://arxiv.org/pdf/2006.13202 for more details.

        Args:
            motion_pred: Predicted motion, at least 3-D, features on the last axis.
            motion_gt: Ground-truth motion, expected to match ``motion_pred``'s shape.

        Returns:
            Scalar tensor: summed NLL over the first ``motion_dim`` features.
        """
        d = self.motion_dim
        return self._optimal_sigma_nll(motion_pred[..., :d], motion_gt[..., :d])

    def forward_KL(self, mu, logvar):
        """KL divergence from N(mu, exp(logvar)) to the standard normal.

        Summed over dims 1 and 2, then averaged over all remaining entries
        (dim 0 for 3-D inputs).
        """
        loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=(1, 2))
        return loss.mean()

    def forward_root(self, motion_pred, motion_gt):
        """Optimal sigma NLL restricted to the root-joint features ``[..., :root_dim]``.

        Uses its own sigma, computed from the root slice only.

        Returns:
            Scalar tensor: summed NLL over the first ``root_dim`` features.
        """
        r = self.root_dim
        return self._optimal_sigma_nll(motion_pred[..., :r], motion_gt[..., :r])