Spaces:
Sleeping
Sleeping
| # Multi-HMR | |
| # Copyright (c) 2024-present NAVER Corp. | |
| # CC BY-NC-SA 4.0 license | |
| import torch | |
| from torch import nn | |
| from torch import nn | |
| import smplx | |
| import torch | |
| import numpy as np | |
| import pose_utils | |
| from pose_utils import inverse_perspective_projection, perspective_projection | |
| import roma | |
| import pickle | |
| import os | |
| from pose_utils.constants_service import SMPLX_DIR | |
| from pose_utils.rot6d import rotation_6d_to_matrix | |
| from smplx.lbs import vertices2joints | |
class SMPL_Layer(nn.Module):
    """
    Extension of the SMPL-X layer with camera information, used to (inverse-)project
    people between 3D camera space and the image plane.
    """

    def __init__(
        self,
        smpl_dir,
        type="smplx",
        gender="neutral",
        num_betas=10,
        kid=False,
        person_center=None,
        *args,
        **kwargs,
    ):
        """
        Args:
            - smpl_dir: model directory passed to ``smplx.create``.
            - type: body-model type; only "smplx" is accepted (asserted).
            - gender: gender passed to ``smplx.create``.
            - num_betas: number of shape coefficients of the SMPL-X model.
            - kid: stored as ``self.kid``; it is not passed to ``smplx.create``.
            - person_center: optional joint name. When given, ``forward`` interprets the
              translation as the location of this joint instead of using the standard
              SMPL-X convention (translation + pelvis rest position).
            - *args, **kwargs: accepted and ignored.
        """
        super().__init__()
        assert type == "smplx"
        self.type = type
        self.kid = kid
        self.num_betas = num_betas
        # use_pca=False: hand poses are given per joint (15 rotations per hand in forward).
        self.bm_x = smplx.create(
            smpl_dir, "smplx", gender=gender, use_pca=False, flat_hand_mean=True, num_betas=num_betas
        )
        # Joint names of the model, used to resolve ``person_center`` to a joint index.
        # getattr performs the same attribute lookup the former eval() did, without evaluating code.
        self.joint_names = getattr(pose_utils, f"get_{self.type}_joint_names")()
        self.person_center = person_center
        self.person_center_idx = None
        if self.person_center is not None:
            self.person_center_idx = self.joint_names.index(self.person_center)

    def forward(
        self,
        pose,
        shape,
        loc,
        dist,
        transl,
        K,
        expression=None,  # facial expression
        rot6d=False,
        j_regressor=None,
    ):
        """
        Run SMPL-X, apply the global orientation and place the people in camera space.

        Args:
            - pose: per-joint rotations - torch.Tensor [bs,53,3] (axis-angle) or [bs,53,6] when
              ``rot6d`` is True. Index 0: global orientation, 1:22 body, 22:37 left hand,
              37:52 right hand, 52 jaw.
            - shape: betas - torch.Tensor [bs,num_betas] or [bs,num_betas+1]
            - loc: 2D location of the primary keypoint in pixel space - torch.Tensor [bs,2]
              (only used to recover ``transl`` when it is None)
            - dist: distance of the primary keypoint from the camera in m - torch.Tensor [bs,1]
              (only used to recover ``transl`` when it is None)
            - transl: translation [bs,3], or None to recover it from ``loc``, ``dist`` and ``K``
            - K: camera intrinsics used by the (inverse) perspective projection
            - expression: optional expression coefficients (flattened to [bs,-1]); the model's
              default expression is used when None
            - rot6d: if True, ``pose`` uses the 6D rotation representation instead of axis-angle
            - j_regressor: optional joint regressor; when given, "j3d"/"j2d" are regressed from
              the camera-space vertices instead of using the SMPL-X joints
        Return:
            - {} when bs == 0; otherwise a dict with "v3d", "j3d" (camera space), "j2d", "v2d"
              (projected with K), "transl", "transl_pelvis" ([bs,1,3]) and "j3d_world"
              (joints of the SMPL-X layer before the global orientation and translation).
        Raises:
            - ValueError if ``transl`` is None and ``loc`` or ``dist`` is missing.
        """
        if loc is not None and dist is not None:
            assert pose.shape[0] == shape.shape[0] == loc.shape[0] == dist.shape[0]
        pose_dim = 6 if rot6d else 3
        if self.type == "smpl":
            assert len(pose.shape) == 3 and list(pose.shape[1:]) == [24, pose_dim]
        elif self.type == "smplx":
            # taking root_orient, body_pose, lhand, rhand and jaw for the moment
            assert len(pose.shape) == 3 and list(pose.shape[1:]) == [53, pose_dim]
        else:
            raise NameError
        assert len(shape.shape) == 2 and list(shape.shape[1:]) in (
            [self.num_betas],
            [self.num_betas + 1],
        )
        if loc is not None and dist is not None:
            assert len(loc.shape) == 2 and list(loc.shape[1:]) == [2]
            assert len(dist.shape) == 2 and list(dist.shape[1:]) == [1]

        bs = pose.shape[0]
        # No humans
        if bs == 0:
            return {}

        # Low-dimensional parameters. The global orientation is kept at the model default
        # here and applied manually below, around the pelvis.
        kwargs_pose = {
            "betas": shape,
            "global_orient": self.bm_x.global_orient.repeat(bs, 1),
            "body_pose": pose[:, 1:22].flatten(1),
            "left_hand_pose": pose[:, 22:37].flatten(1),
            "right_hand_pose": pose[:, 37:52].flatten(1),
            "jaw_pose": pose[:, 52:53].flatten(1),
        }
        if expression is not None:
            kwargs_pose["expression"] = expression.flatten(1)  # [bs,10]
        else:
            kwargs_pose["expression"] = self.bm_x.expression.repeat(bs, 1)
        # Eye poses: model defaults - to be generalized
        kwargs_pose["leye_pose"] = self.bm_x.leye_pose.repeat(bs, 1)
        kwargs_pose["reye_pose"] = self.bm_x.reye_pose.repeat(bs, 1)

        # Forward using the parametric 3D model SMPL-X layer
        output = self.bm_x(pose2rot=not rot6d, **kwargs_pose)
        verts = output.vertices
        j3d = output.joints

        if rot6d:
            R = rotation_6d_to_matrix(pose[:, 0])
        else:
            R = roma.rotvec_to_rotmat(pose[:, 0])
        # Apply the global orientation around the pelvis (joint 0) to joints and vertices;
        # afterwards the pelvis sits at the origin.
        pelvis = j3d[:, [0]]
        j3d = (R.unsqueeze(1) @ (j3d - pelvis).unsqueeze(-1)).squeeze(-1)
        verts = (R.unsqueeze(1) @ (verts - pelvis).unsqueeze(-1)).squeeze(-1)

        # Location of the person in 3D
        if transl is None:
            if loc is None or dist is None:
                raise ValueError("loc and dist are required when transl is None")
            if K.dtype == torch.float16:
                # because of torch.inverse - not working with float16 at the moment:
                # compute in float32, then cast back
                transl = inverse_perspective_projection(
                    loc.unsqueeze(1).float(), K.float(), dist.unsqueeze(1).float()
                )[:, 0]
                transl = transl.half()
            else:
                transl = inverse_perspective_projection(loc.unsqueeze(1), K, dist.unsqueeze(1))[:, 0]

        # Definition of the translation depends on the args:
        # 1) vanilla SMPL-X - 2) computed from a given joint.
        # In both cases transl_pelvis = (pelvis in camera space) - (pelvis rest position).
        if self.person_center_idx is None:
            # Standard SMPL-X convention: the pelvis ends up at transl + its rest position
            transl_up = transl + pelvis[:, 0]
            transl_pelvis = transl.unsqueeze(1)
        else:
            # Center around the joint because the translation is computed from this joint
            person_center = j3d[:, [self.person_center_idx]]
            verts = verts - person_center
            j3d = j3d - person_center
            transl_up = transl
            # Pelvis in camera space is transl - person_center
            transl_pelvis = transl.unsqueeze(1) - person_center - pelvis

        # Moving into the camera coordinate system
        j3d_cam = j3d + transl_up.unsqueeze(1)
        verts_cam = verts + transl_up.unsqueeze(1)
        if j_regressor is not None:
            # for smplify
            j3d_cam = vertices2joints(j_regressor, verts_cam)
        # Projection in camera plane
        j2d = perspective_projection(j3d_cam, K)
        v2d = perspective_projection(verts_cam, K)
        return {
            "v3d": verts_cam,  # in 3d camera space
            "j3d": j3d_cam,  # in 3d camera space
            "j2d": j2d,
            "v2d": v2d,
            "transl": transl,  # translation used to place the person (see person_center)
            "transl_pelvis": transl_pelvis,  # root=pelvis
            "j3d_world": output.joints,
        }

    def forward_local(self, pose, shape):
        """
        Run SMPL-X with the model's default global orientation and no camera placement.

        Args:
            - pose: torch.Tensor [N,53,L] (same joint layout as ``forward``) or [N,55,L]
              (0 global, 1:22 body, 22 jaw, 23:25 eyes, 25:40 left hand, 40:55 right hand).
              Rotations are passed with the SMPL-X layer's default ``pose2rot``. The global
              orientation (index 0) and, for the 55-joint layout, the eye rotations are not
              used: model defaults are passed instead.
            - shape: betas passed to the SMPL-X layer
        Return:
            - None when N == 0, otherwise the output of the SMPL-X layer
        Raises:
            - ValueError if the pose has neither 53 nor 55 joints
        """
        n_people, n_joints, _ = pose.shape
        if n_people < 1:
            return None
        # (left hand, right hand, jaw) joint slices for each supported layout
        if n_joints == 53:
            lhand, rhand, jaw = slice(22, 37), slice(37, 52), slice(52, 53)
        elif n_joints == 55:
            lhand, rhand, jaw = slice(25, 40), slice(40, 55), slice(22, 23)
        else:
            raise ValueError(f"pose dim error, should be 53 or 55, but got {n_joints}")
        kwargs_pose = {
            "betas": shape,
            "global_orient": self.bm_x.global_orient.repeat(n_people, 1),
            "body_pose": pose[:, 1:22].flatten(1),
            "left_hand_pose": pose[:, lhand].flatten(1),
            "right_hand_pose": pose[:, rhand].flatten(1),
            "jaw_pose": pose[:, jaw].flatten(1),
            "expression": self.bm_x.expression.repeat(n_people, 1),
            # default - to be generalized
            "leye_pose": self.bm_x.leye_pose.repeat(n_people, 1),
            "reye_pose": self.bm_x.reye_pose.repeat(n_people, 1),
        }
        return self.bm_x(**kwargs_pose)

    def convert_standard_pose(self, poses):
        """
        Convert a 53-joint pose (``forward`` layout) to the 55-joint SMPL-X layout.

        Output layout: 0:22 global + body, 22 jaw, 23 left eye, 24 right eye, 25:55 hands.
        Eye rotations are filled with the model's default eye poses.

        Args:
            - poses: torch.Tensor [N,53,3]
        Return:
            - torch.Tensor [N,55,3]
        Raises:
            - ValueError if ``poses`` does not have 53 joints
        """
        if poses.shape[1] != 53:
            raise ValueError(f"expected a 53-joint pose, but got {poses.shape[1]} joints")
        n = poses.shape[0]
        return torch.cat(
            [
                poses[:, :22],  # global orientation + body
                poses[:, 52:53],  # jaw
                self.bm_x.leye_pose.repeat(n, 1, 1),
                self.bm_x.reye_pose.repeat(n, 1, 1),
                poses[:, 22:52],  # left + right hands
            ],
            dim=1,
        )