| import torch |
| import torch.nn as nn |
|
|
class ReConsLoss(nn.Module):
    """Reconstruction loss over the leading feature channels of a motion tensor.

    Wraps one of PyTorch's element-wise regression losses (L1, MSE or
    Smooth-L1, each with its default ``'mean'`` reduction). ``forward`` applies
    it only to the first ``motion_dim`` channels of the last dimension, so any
    trailing channels are ignored.

    NOTE(review): the channel arithmetic (4 leading channels, 12 per non-root
    joint, plus 4 + 3 + 4 extra) looks like a HumanML3D-style feature layout.
    That would be root features, joint positions/rotations/velocities and foot
    contacts. Confirm against the code that produces these tensors.
    """

    # Accepted ``recons_loss`` names -> loss-module classes.
    _LOSSES = {
        'l1': nn.L1Loss,
        'l2': nn.MSELoss,
        'l1_smooth': nn.SmoothL1Loss,
    }

    def __init__(self, recons_loss, nb_joints):
        """
        Args:
            recons_loss: one of ``'l1'``, ``'l2'`` or ``'l1_smooth'``.
            nb_joints: number of joints; determines ``motion_dim`` and the
                channel range compared by ``forward_joint``.

        Raises:
            ValueError: if ``recons_loss`` is not a supported name. Previously
                this silently left ``self.Loss`` unset and failed later with an
                ``AttributeError`` inside ``forward``.
        """
        super().__init__()

        try:
            loss_cls = self._LOSSES[recons_loss]
        except KeyError:
            raise ValueError(
                f"Unknown recons_loss {recons_loss!r}; "
                f"expected one of {sorted(self._LOSSES)}"
            ) from None
        # Assigning an nn.Module attribute registers it as a child module.
        self.Loss = loss_cls()

        self.nb_joints = nb_joints
        # Number of leading channels compared by forward().
        self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4

    def forward(self, motion_pred, motion_gt):
        """Return the loss between the first ``motion_dim`` channels
        (last dimension) of ``motion_pred`` and ``motion_gt``.
        """
        dim = self.motion_dim
        return self.Loss(motion_pred[..., :dim], motion_gt[..., :dim])

    def forward_joint(self, motion_pred, motion_gt):
        """Return the loss restricted to channels ``[4, 4 + (nb_joints - 1) * 3)``.

        NOTE(review): presumably these are the non-root joint positions
        (3 values per joint after 4 root channels). Confirm with the data
        pipeline.
        """
        start = 4
        stop = start + (self.nb_joints - 1) * 3
        return self.Loss(motion_pred[..., start:stop], motion_gt[..., start:stop])
| |
| |