Spaces:
Sleeping
Sleeping
| from __future__ import absolute_import | |
| from __future__ import print_function | |
| from __future__ import division | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from configs import constants as _C | |
| from lib.utils import transforms | |
| from lib.utils.kp_utils import root_centering | |
class WHAMLoss(nn.Module):
    """Weighted sum of the training loss terms defined in this module.

    ``forward`` evaluates every term (2D/3D keypoints, cascaded keypoints,
    vertices, SMPL pose/shape, contact, root trajectory before and after
    refinement, foot sliding, camera), scales each one by its coefficient
    from ``cfg.LOSS`` and by the global ``cfg.LOSS.LOSS_WEIGHT``, and returns
    the total together with the per-term dictionary.
    """

    def __init__(
        self,
        cfg=None,
        device=None,
    ):
        """Read loss coefficients from ``cfg`` and build per-joint weights.

        Args:
            cfg: Config object exposing the ``LOSS.*`` fields read below.
                Although it defaults to ``None``, the constructor
                dereferences ``cfg.LOSS`` immediately, so a real config is
                required.
            device: Target passed to ``Tensor.to`` for the weight tensors.
        """
        super().__init__()

        self.cfg = cfg
        self.n_joints = _C.KEYPOINTS.NUM_JOINTS
        self.criterion = nn.MSELoss()
        self.criterion_noreduce = nn.MSELoss(reduction='none')

        # Per-term coefficients.
        loss_cfg = cfg.LOSS
        self.pose_loss_weight = loss_cfg.POSE_LOSS_WEIGHT
        self.shape_loss_weight = loss_cfg.SHAPE_LOSS_WEIGHT
        self.keypoint_2d_loss_weight = loss_cfg.JOINT2D_LOSS_WEIGHT
        self.keypoint_3d_loss_weight = loss_cfg.JOINT3D_LOSS_WEIGHT
        self.cascaded_loss_weight = loss_cfg.CASCADED_LOSS_WEIGHT
        self.vertices_loss_weight = loss_cfg.VERTS3D_LOSS_WEIGHT
        self.contact_loss_weight = loss_cfg.CONTACT_LOSS_WEIGHT
        self.root_vel_loss_weight = loss_cfg.ROOT_VEL_LOSS_WEIGHT
        self.root_pose_loss_weight = loss_cfg.ROOT_POSE_LOSS_WEIGHT
        self.sliding_loss_weight = loss_cfg.SLIDING_LOSS_WEIGHT
        self.camera_loss_weight = loss_cfg.CAMERA_LOSS_WEIGHT
        # Global multiplier applied to every term.
        self.loss_weight = loss_cfg.LOSS_WEIGHT

        # Per-keypoint weights (31 entries). Group labels are carried over
        # from the original annotations.
        kp_weights = [
            0.5, 0.5, 0.5, 0.5, 0.5,        # Face
            1.5, 1.5, 4, 4, 4, 4,           # Arms
            1.5, 1.5, 4, 4, 4, 4,           # Legs
            4, 4, 1.5, 1.5, 4, 4,           # Legs
            4, 4, 1.5, 1.5, 4, 4,           # Arms
            0.5, 0.5,                       # Head
        ]

        # Per-joint weights for the pose regression term (24 entries).
        theta_weights = [
            0.1, 1.0, 1.0, 1.0, 1.0,        # pelvis, lhip, rhip, spine1, lknee
            1.0, 1.0, 1.0, 1.0, 1.0,        # rknee, spine2, lankle, rankle, spine3
            0.1, 0.1,                       # Foot
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0,   # neck, lisldr, risldr, head, losldr, rosldr
            1.0, 1.0, 1.0, 1.0,             # lelbow, relbow, lwrist, rwrist
            0.1, 0.1,                       # Hand
        ]

        # Shape (1, 1, 24), rescaled so the weights average to one.
        joint_weights = torch.tensor([[theta_weights]]).float().to(device)
        self.theta_weights = joint_weights / joint_weights.mean()
        # Shape (1, 31).
        self.kp_weights = torch.tensor([kp_weights]).float().to(device)

        # ``step`` increments the counter, so training starts at epoch 0.
        self.epoch = -1
        self.step()

    def step(self):
        """Advance the epoch counter and refresh the camera-loss skip flag."""
        self.epoch += 1
        skip_until = self.cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH
        self.skip_camera_loss = self.epoch < skip_until

    def forward(self, pred, gt):
        """Compute the total loss and the individual weighted terms.

        Args:
            pred: Mapping of network outputs. Keys read: ``betas``, ``pose``,
                ``kp3d_nn``, ``kp3d``, ``full_kp2d``, ``weak_kp2d``,
                ``contact``, ``vel_root``, ``poses_root_r6d``,
                ``vel_root_refined``, ``poses_root_r6d_refined``, ``R``,
                ``verts_cam``, ``feet``, ``feet_refined``.
            gt: Mapping of targets. Keys read: ``kp3d``, ``betas``, ``pose``,
                ``full_kp2d``, ``weak_kp2d``, ``contact``, ``vel_root``,
                ``pose_root``, ``cam_angvel``, ``R``, ``bbox``, ``verts``,
                ``has_verts``, ``has_smpl``, ``has_traj``.

        Returns:
            Tuple ``(total, loss_dict)`` where ``loss_dict`` maps term names
            to their weighted values and ``total`` is their sum.
        """
        batch, frames = gt['kp3d'].shape[:2]
        elementwise = self.criterion_noreduce
        n_joints = self.n_joints
        kp_weights = self.kp_weights

        # ---- Predictions -------------------------------------------------
        betas_hat = pred['betas']
        pose_hat = pred['pose'].reshape(batch, frames, -1, 6)
        kp3d_nn_hat = pred['kp3d_nn']
        kp3d_smpl_hat = root_centering(pred['kp3d'].reshape(batch, frames, -1, 3))
        full_kp2d_hat = pred['full_kp2d']
        weak_kp2d_hat = pred['weak_kp2d']
        contact_hat = pred['contact']
        vel_root_hat = pred['vel_root']
        # Root poses (and the ground-truth camera rotation below) drop their
        # first frame.
        pose_root_hat = pred['poses_root_r6d'][:, 1:]
        vel_root_ref_hat = pred['vel_root_refined']
        pose_root_ref_hat = pred['poses_root_r6d_refined'][:, 1:]
        cam_r_hat = transforms.matrix_to_rotation_6d(pred['R'])

        # ---- Ground truth ------------------------------------------------
        betas_gt = gt['betas']
        pose_gt = gt['pose']
        kp3d_gt = root_centering(gt['kp3d'])
        full_kp2d_gt = gt['full_kp2d']
        weak_kp2d_gt = gt['weak_kp2d']
        contact_gt = gt['contact']
        vel_root_gt = gt['vel_root']
        pose_root_gt = gt['pose_root'][:, 1:]
        cam_angvel_gt = gt['cam_angvel']
        cam_r_gt = transforms.matrix_to_rotation_6d(gt['R'][:, 1:])
        bbox = gt['bbox']

        # ---- 2D keypoints ------------------------------------------------
        kp2d_full = full_projected_keypoint_loss(
            full_kp2d_hat,
            full_kp2d_gt,
            bbox,
            kp_weights,
            criterion=elementwise,
        )
        kp2d_weak = weak_projected_keypoint_loss(
            weak_kp2d_hat,
            weak_kp2d_gt,
            kp_weights,
            criterion=elementwise,
        )

        # ---- 3D keypoints ------------------------------------------------
        kp3d_nn = keypoint_3d_loss(
            kp3d_nn_hat,
            kp3d_gt[:, :, :n_joints],
            kp_weights[:, :n_joints],
            criterion=elementwise,
        )
        kp3d_smpl = keypoint_3d_loss(
            kp3d_smpl_hat,
            kp3d_gt,
            kp_weights,
            criterion=elementwise,
        )
        # Cascaded term: the keypoints from ``pred['kp3d_nn']`` are compared
        # with the first ``n_joints`` root-centred keypoints from
        # ``pred['kp3d']``, using the ground-truth confidence channel.
        cascaded_target = torch.cat(
            (kp3d_smpl_hat[:, :, :n_joints], kp3d_gt[:, :, :n_joints, -1:]),
            dim=-1,
        )
        cascaded = keypoint_3d_loss(
            kp3d_nn_hat,
            cascaded_target,
            kp_weights[:, :n_joints] * 0.5,
            criterion=elementwise,
        )

        # ---- Mesh vertices -----------------------------------------------
        vertices = vertices_loss(
            pred['verts_cam'],
            gt['verts'],
            gt['has_verts'],
            criterion=elementwise,
        )

        # ---- SMPL parameters ---------------------------------------------
        has_smpl = gt['has_smpl']
        regr_pose, regr_betas = smpl_losses(
            pose_hat,
            betas_hat,
            pose_gt,
            betas_gt,
            self.theta_weights,
            has_smpl,
            criterion=elementwise,
        )

        # ---- Foot contact ------------------------------------------------
        contact = contact_loss(contact_hat, contact_gt, elementwise)

        # ---- Root velocity / orientation, before and after refinement ----
        vel_root, pose_root = root_loss(
            vel_root_hat,
            pose_root_hat,
            vel_root_gt,
            pose_root_gt,
            contact_gt,
            elementwise,
        )
        vel_root_ref, pose_root_ref = root_loss(
            vel_root_ref_hat,
            pose_root_ref_hat,
            vel_root_gt,
            pose_root_gt,
            contact_gt,
            elementwise,
        )

        # ---- Camera ------------------------------------------------------
        camera = camera_loss(
            cam_r_hat,
            cam_r_gt,
            cam_angvel_gt[:, 1:],
            gt['has_traj'],
            elementwise,
            self.skip_camera_loss,
        )

        # ---- Foot sliding, before and after refinement -------------------
        sliding = sliding_loss(pred['feet'], contact_gt)
        sliding_ref = sliding_loss(pred['feet_refined'], contact_gt)

        # ---- Weighting ---------------------------------------------------
        global_w = self.loss_weight
        kp2d = (kp2d_full + kp2d_weak) * self.keypoint_2d_loss_weight
        root = (
            vel_root * self.root_vel_loss_weight
            + pose_root * self.root_pose_loss_weight
        )
        root_ref = (
            vel_root_ref * self.root_vel_loss_weight
            + pose_root_ref * self.root_pose_loss_weight
        )

        loss_dict = {
            'pose': regr_pose * self.pose_loss_weight * global_w,
            'betas': regr_betas * self.shape_loss_weight * global_w,
            '2d': kp2d * global_w,
            '3d': kp3d_smpl * self.keypoint_3d_loss_weight * global_w,
            '3d_nn': kp3d_nn * self.keypoint_3d_loss_weight * global_w,
            'casc': cascaded * self.cascaded_loss_weight * global_w,
            'v3d': vertices * self.vertices_loss_weight * global_w,
            'contact': contact * self.contact_loss_weight * global_w,
            'root': root * global_w,
            'root_ref': root_ref * global_w,
            'sliding': sliding * self.sliding_loss_weight * global_w,
            'camera': camera * self.camera_loss_weight * global_w,
            'sliding_ref': sliding_ref * self.sliding_loss_weight * global_w,
        }

        total = sum(loss_dict.values())
        return total, loss_dict
def root_loss(
    pred_vel_root,
    pred_pose_root,
    gt_vel_root,
    gt_pose_root,
    stationary,
    criterion
):
    """Compute the root velocity loss and the root pose loss.

    Args:
        pred_vel_root: Predicted root velocity.
        pred_pose_root: Predicted root pose.
        gt_vel_root: Target root velocity; a sample is used only when every
            entry is non-zero.
        gt_pose_root: Target root pose; a sample is used only when every
            entry is non-zero.
        stationary: Contact labels; a sample counts as labelled when any
            entry differs from -1.
        criterion: Element-wise loss (no reduction) for the pose term.

    Returns:
        Tuple ``(loss_v, loss_r)`` of scalar tensors. Each is zero when no
        sample in the batch is valid for that term.
    """
    # One boolean per sample (first axis).
    pose_valid = (gt_pose_root != 0.0).all(dim=-1).all(dim=-1)
    vel_valid = (gt_vel_root != 0.0).all(dim=-1).all(dim=-1)
    contact_labelled = (stationary != -1).any(dim=1).any(dim=1)
    vel_valid = vel_valid * contact_labelled

    # ---- Pose term -------------------------------------------------------
    if pose_valid.any():
        loss_r = criterion(pred_pose_root, gt_pose_root)[pose_valid].mean()
    else:
        loss_r = torch.zeros((), dtype=torch.float32, device=gt_pose_root.device)

    # ---- Velocity term ---------------------------------------------------
    if not vel_valid.any():
        loss_v = torch.zeros((), dtype=torch.float32, device=gt_vel_root.device)
        return loss_v, loss_r

    # NOTE(review): the window count is taken from axis 0 while the slices
    # below index axis 1, and each window starts at ``m`` rather than
    # ``m * ws``. Kept as-is to preserve the loss values; confirm intent.
    n_steps = gt_vel_root.shape[0]
    accumulated = 0
    for window in (1, 3, 9, 27):
        window_total = sum(
            torch.norm(
                torch.sum(
                    pred_vel_root[:, m:(m + 1) * window]
                    - gt_vel_root[:, m:(m + 1) * window],
                    dim=1,
                ),
                dim=-1,
            )
            for m in range(n_steps // window)
        )
        accumulated = accumulated + window_total

    loss_v = accumulated[vel_valid].mean()
    return loss_v, loss_r
def contact_loss(
    pred_stationary,
    gt_stationary,
    criterion,
):
    """Average the element-wise contact loss over labelled entries.

    Args:
        pred_stationary: Predicted contact values.
        gt_stationary: Target contact values; entries equal to -1 are
            treated as unlabelled and excluded.
        criterion: Element-wise loss (no reduction).

    Returns:
        Scalar tensor; zero when every target entry is -1.
    """
    labelled = gt_stationary != -1
    if not labelled.any():
        return torch.zeros((), dtype=torch.float32, device=gt_stationary.device)

    per_entry = criterion(pred_stationary, gt_stationary)
    return per_entry[labelled].mean()
def full_projected_keypoint_loss(
    pred_keypoints_2d,
    gt_keypoints_2d,
    bbox,
    weight,
    criterion,
):
    """Confidence-weighted 2D keypoint distance, normalised by bbox scale.

    Args:
        pred_keypoints_2d: Predicted 2D keypoints (last axis holds x, y).
        gt_keypoints_2d: Target keypoints; the last axis holds x, y and a
            confidence value in its final slot.
        bbox: Bounding-box tensor; everything from index 2 onward on the
            last axis, multiplied by 200, is used as the divisor.
        weight: Per-keypoint weights, broadcast against the distances.
        criterion: Unused; kept for a uniform call signature.

    Returns:
        Scalar tensor; zero when no keypoint has positive confidence.
    """
    scale = bbox[..., 2:] * 200.
    conf = gt_keypoints_2d[..., -1]

    if not (conf > 0).any():
        return torch.zeros((), dtype=torch.float32, device=gt_keypoints_2d.device)

    distance = torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)
    normalised = weight * (conf * distance) / scale
    # Average over axis 1 first, then over everything else, and scale by the
    # mean confidence.
    return normalised.mean(dim=1).mean() * conf.mean()
def weak_projected_keypoint_loss(
    pred_keypoints_2d,
    gt_keypoints_2d,
    weight,
    criterion,
):
    """Confidence-weighted 2D keypoint distance, scaled by a factor of 5.

    Args:
        pred_keypoints_2d: Predicted 2D keypoints (last axis holds x, y).
        gt_keypoints_2d: Target keypoints; the last axis holds x, y and a
            confidence value in its final slot.
        weight: Per-keypoint weights, broadcast against the distances.
        criterion: Unused; kept for a uniform call signature.

    Returns:
        Scalar tensor; zero when no keypoint has positive confidence.
    """
    conf = gt_keypoints_2d[..., -1]

    if not (conf > 0).any():
        return torch.zeros((), dtype=torch.float32, device=gt_keypoints_2d.device)

    distance = torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1)
    weighted = weight * (conf * distance)
    return weighted.mean(dim=1).mean() * conf.mean() * 5
def keypoint_3d_loss(
    pred_keypoints_3d,
    gt_keypoints_3d,
    weight,
    criterion,
):
    """Confidence-weighted 3D keypoint distance.

    Args:
        pred_keypoints_3d: Predicted 3D keypoints.
        gt_keypoints_3d: Target keypoints; the first three slots of the last
            axis are coordinates and the final slot is a confidence value.
        weight: Per-keypoint weights, broadcast against the distances. Must
            have at least two dimensions (``weight.shape[-2]`` is read).
        criterion: Unused; kept for a uniform call signature.

    Returns:
        Scalar tensor; zero when no keypoint has positive confidence.
    """
    conf = gt_keypoints_3d[..., -1]

    if not (conf > 0).any():
        return torch.zeros((), dtype=torch.float32, device=gt_keypoints_3d.device)

    if weight.shape[-2] > 17:
        # NOTE(review): this branch rewrites both inputs in place, slices the
        # LAST axis with ``-14:``, and -- because ``conf`` is a view of the
        # ground truth -- also alters the confidences used below. It looks
        # unintended, but is kept as-is to preserve results; confirm intent.
        pred_tail = pred_keypoints_3d[..., -14:]
        gt_tail = gt_keypoints_3d[..., -14:]
        pred_keypoints_3d[..., -14:] = pred_tail - pred_tail.mean(dim=-2, keepdims=True)
        gt_keypoints_3d[..., -14:] = gt_tail - gt_tail.mean(dim=-2, keepdims=True)

    distance = torch.norm(pred_keypoints_3d - gt_keypoints_3d[..., :3], dim=-1)
    weighted = weight * (conf * distance)
    return weighted.mean(dim=1).mean() * conf.mean()
def vertices_loss(
    pred_verts,
    gt_verts,
    mask,
    criterion,
):
    """L1 vertex distance after removing each mesh's mean position.

    Args:
        pred_verts: Predicted vertices; viewed to the shape of ``gt_verts``.
        gt_verts: Target vertices (second-to-last axis indexes vertices).
        mask: Boolean selector over the leading axes marking samples that
            carry vertex supervision.
        criterion: Unused; kept for a uniform call signature.

    Returns:
        Scalar tensor; zero when ``mask`` selects nothing.
    """
    if not (mask.sum() > 0):
        return torch.zeros((), dtype=torch.float32, device=gt_verts.device)

    # Align both meshes by subtracting their mean vertex.
    pred_centered = pred_verts.view_as(gt_verts)
    pred_centered = pred_centered - pred_centered.mean(-2, True)
    gt_centered = gt_verts - gt_verts.mean(-2, True)

    per_vertex = torch.norm(pred_centered - gt_centered, p=1, dim=-1)
    selected = per_vertex[mask]
    # Scale by the fraction of supervised samples.
    return torch.mean(selected, dim=1).mean() * mask.float().mean()
def smpl_losses(
    pred_pose,
    pred_betas,
    gt_pose,
    gt_betas,
    weight,
    mask,
    criterion,
):
    """Squared-error losses on SMPL pose and shape parameters.

    Args:
        pred_pose: Predicted pose parameters.
        pred_betas: Predicted shape parameters.
        gt_pose: Target pose parameters.
        gt_betas: Target shape parameters.
        weight: Per-joint weights applied to the pose error after it has
            been averaged over the last axis.
        mask: Boolean selector marking samples with SMPL supervision.
        criterion: Unused; kept for a uniform call signature.

    Returns:
        Tuple ``(loss_regr_pose, loss_regr_betas)`` of scalar tensors, both
        zero when ``mask`` selects nothing.
    """
    if not mask.any().item():
        device = gt_pose.device
        zero_pose = torch.zeros((), dtype=torch.float32, device=device)
        zero_betas = torch.zeros((), dtype=torch.float32, device=device)
        return zero_pose, zero_betas

    # Fraction of supervised samples, applied to both terms.
    coverage = mask.float().mean()

    pose_sq_err = torch.square(pred_pose - gt_pose)[mask].mean(-1)
    loss_regr_pose = torch.mean(weight * pose_sq_err) * coverage

    betas_sq_err = F.mse_loss(pred_betas, gt_betas, reduction='none')[mask]
    loss_regr_betas = betas_sq_err.mean() * coverage

    return loss_regr_pose, loss_regr_betas
def camera_loss(
    pred_cam_r,
    gt_cam_r,
    cam_angvel,
    mask,
    criterion,
    skip
):
    """Camera rotation loss plus angular-velocity reconstruction loss.

    Args:
        pred_cam_r: Predicted camera rotation in 6D representation.
        gt_cam_r: Target camera rotation in 6D representation.
        cam_angvel: Target camera angular velocity.
        mask: Boolean selector marking samples with camera supervision.
        criterion: Element-wise loss (no reduction).
        skip: When truthy the loss is disabled and zero is returned.

    Returns:
        Scalar tensor; zero when ``skip`` is set or ``mask`` selects nothing.
    """
    if not mask.any() or skip:
        return torch.zeros((), dtype=torch.float32, device=gt_cam_r.device)

    # Rotation term, computed directly on the 6D representation.
    rotation_term = criterion(pred_cam_r, gt_cam_r)[mask].mean()

    # Rebuild the angular velocity from consecutive predicted rotations and
    # compare it with the target.
    rot_mats = transforms.rotation_6d_to_matrix(pred_cam_r)
    relative_rot = rot_mats[:, :-1] @ rot_mats[:, 1:].transpose(-1, -2)
    relative_6d = transforms.matrix_to_rotation_6d(relative_rot)
    identity_6d = torch.tensor([[[1, 0, 0, 0, 1, 0]]]).to(cam_angvel)
    # NOTE(review): the factor 30 is presumably the frame rate -- confirm.
    reconstructed_angvel = (relative_6d - identity_6d) * 30
    angvel_term = criterion(cam_angvel, reconstructed_angvel)[mask].mean()

    return rotation_term + angvel_term
def sliding_loss(
    foot_position,
    contact_prob,
):
    """Penalise foot motion on frames where the foot is in contact.

    Args:
        foot_position: 3D foot (heel and toe) positions, torch.Tensor
            (B, F, 4, 3).
        contact_prob: Contact probability of the foot (heel and toe),
            torch.Tensor (B, F, 4). Values above 0.5 count as contact.

    Returns:
        Scalar tensor: mean frame-to-frame displacement magnitude, with
        non-contact entries contributing zero.
    """
    # Hard, gradient-free contact indicator.
    in_contact = (contact_prob > 0.5).detach().float()

    # Displacement between consecutive frames; the first frame has none, so
    # the indicator is aligned by dropping its first frame too.
    displacement = foot_position[:, 1:] - foot_position[:, :-1]
    speed = torch.norm(displacement, dim=-1)

    return (speed * in_contact[:, 1:]).mean()