Spaces:
Sleeping
Sleeping
| import torch | |
def gmof(x, sigma):
    """Apply the Geman-McClure robust penalty element-wise.

    Evaluates ``sigma**2 * x**2 / (sigma**2 + x**2)``: the value behaves like
    ``x**2`` while ``|x|`` is small compared with ``sigma`` and saturates
    towards ``sigma**2`` as ``|x|`` grows, which limits the influence of
    outlier residuals.

    Args:
        x: Residual value(s); a tensor or a Python number.
        sigma: Scale parameter controlling where the penalty saturates.

    Returns:
        The penalty, broadcast over ``x`` and ``sigma``.
    """
    sq_residual = x ** 2
    sq_scale = sigma ** 2
    numerator = sq_scale * sq_residual
    denominator = sq_scale + sq_residual
    return numerator / denominator
def compute_jitter(x):
    """Return the magnitude of the second difference of ``x`` along dim 1.

    For every interior index ``t`` of dim 1 this evaluates
    ``x[:, t + 1] + x[:, t - 1] - 2 * x[:, t]`` and takes its Euclidean norm
    over the last dimension. The result therefore has dim 1 shortened by two
    and the last dimension removed.
    """
    following = x[:, 2:]
    preceding = x[:, :-2]
    current = x[:, 1:-1]
    second_diff = following + preceding - 2 * current
    return torch.linalg.norm(second_diff, dim=-1)
class SMPLifyLoss(torch.nn.Module):
    """Weighted objective terms for SMPLify-style fitting over a sequence.

    ``forward`` returns a dict of already-weighted loss terms:
    - a robust 2D keypoint reprojection term;
    - pose regularization towards ``init_pose``;
    - a shape prior plus shape consistency along dim 1;
    - smoothness of pose and camera along dim 1.

    ``create_closure`` wraps the model call, ``forward`` and back-propagation
    into an optimizer closure.
    """

    def __init__(self,
                 res,
                 cam_intrinsics,
                 init_pose,
                 device,
                 **kwargs
                 ):
        """Store the fitting context.

        Args:
            res: Forwarded unchanged to ``smpl`` as ``res=`` in
                ``create_closure`` (presumably the image resolution).
            cam_intrinsics: Forwarded unchanged to ``smpl`` as
                ``cam_intrinsics=`` in ``create_closure``.
            init_pose: Target of the pose regularization term. Accepts a
                NumPy array or a torch tensor; it is cast to float32 and moved
                to ``device``.
            device: Device on which ``init_pose`` is stored.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.res = res
        self.cam_intrinsics = cam_intrinsics
        # torch.as_tensor accepts NumPy arrays (sharing memory, exactly like
        # torch.from_numpy) and also tensors, which torch.from_numpy rejects.
        self.init_pose = torch.as_tensor(init_pose).float().to(device)

    def forward(self, output, params, input_keypoints, bbox,
                reprojection_weight=100., regularize_weight=60.0,
                consistency_weight=10.0, sprior_weight=0.04,
                smooth_weight=20.0, sigma=100):
        """Compute the weighted loss terms.

        Args:
            output: Result of the ``smpl`` call. The first 17 joints of its
                ``full_joints2d`` are compared with ``input_keypoints``
                (presumably a 17-keypoint layout such as COCO; confirm).
            params: ``(pose, shape, cam)`` tuple being optimized. The
                consistency and smoothness terms treat dim 1 as the temporal
                axis (assumes ``(batch, frames, ...)``; TODO confirm).
            input_keypoints: Target keypoints. The last channel is used as a
                per-joint confidence and the remaining channels as 2D
                coordinates.
            bbox: Bounding boxes. ``bbox[..., 2:] * 200`` is used as the
                normalization scale (presumably ``(cx, cy, s)``; confirm).
            reprojection_weight, regularize_weight, consistency_weight,
            sprior_weight, smooth_weight: Multipliers for the loss terms.
            sigma: Geman-McClure scale applied to the reprojection residual.

        Returns:
            dict with keys ``'reprojection'``, ``'regularize'``, ``'shape'``
            and ``'smooth'``, each an already-weighted scalar tensor.
        """
        pose, shape, cam = params
        scale = bbox[..., 2:].unsqueeze(-1) * 200.

        # Loss 1. Data term: robust, confidence-weighted reprojection error,
        # normalized by the bounding-box scale.
        pred_keypoints = output.full_joints2d[..., :17, :]
        joints_conf = input_keypoints[..., -1:]
        reprojection_error = gmof(pred_keypoints - input_keypoints[..., :-1], sigma)
        reprojection_error = ((reprojection_error * joints_conf) / scale).mean()

        # Loss 2. Regularization term: distance from the initial pose.
        regularize_error = torch.linalg.norm(pose - self.init_pose, dim=-1).mean()

        # Loss 3. Shape prior and consistency error.
        consistency_error = self._shape_consistency(shape)
        sprior_error = torch.linalg.norm(shape, dim=-1).mean()
        shape_error = sprior_weight * sprior_error + consistency_weight * consistency_error

        # Loss 4. Smooth loss on pose and camera.
        smooth_error = self._mean_jitter(pose) + self._mean_jitter(cam)

        # Sum up losses
        loss = {
            'reprojection': reprojection_weight * reprojection_error,
            'regularize': regularize_weight * regularize_error,
            'shape': shape_error,
            'smooth': smooth_weight * smooth_error
        }
        return loss

    @staticmethod
    def _shape_consistency(shape):
        """Mean std of ``shape`` over dim 1, or 0 when dim 1 has < 2 entries.

        ``Tensor.std`` uses the unbiased estimator by default. With a single
        entry it returns NaN, which would make the whole summed loss NaN.
        """
        if shape.shape[1] < 2:
            return shape.new_zeros(())
        return shape.std(dim=1).mean()

    @staticmethod
    def _mean_jitter(x):
        """Mean of ``compute_jitter(x)``, or 0 when dim 1 has < 3 entries.

        With fewer than three entries the second difference is empty, and the
        mean of an empty tensor is NaN.
        """
        if x.shape[1] < 3:
            return x.new_zeros(())
        return compute_jitter(x).mean()

    def create_closure(self,
                       optimizer,
                       smpl,
                       params,
                       bbox,
                       input_keypoints):
        """Build a zero-argument closure that evaluates and back-propagates.

        Each call does the following:
        - zeroes ``optimizer``'s gradients;
        - calls ``smpl(*params, ...)`` with the stored intrinsics and
          resolution;
        - evaluates ``forward``;
        - back-propagates the sum of all loss terms and returns that sum.

        Presumably passed to ``optimizer.step(closure)`` (e.g. LBFGS).
        """
        def closure():
            optimizer.zero_grad()
            output = smpl(*params, cam_intrinsics=self.cam_intrinsics, bbox=bbox, res=self.res)
            loss_dict = self.forward(output, params, input_keypoints, bbox)
            loss = sum(loss_dict.values())
            loss.backward()
            return loss
        return closure