| |
|
|
| from typing import Optional |
| from dataclasses import dataclass |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import pickle |
| from lib.smplx import SMPL as _SMPL |
| from lib.smplx import SMPLXLayer, MANOLayer, FLAMELayer |
| from lib.smplx.lbs import batch_rodrigues, transform_mat, vertices2joints, blend_shapes |
| from lib.smplx.body_models import SMPLXOutput |
| import json |
|
|
| from lib.pymafx.core import path_config, constants |
|
|
# Model-asset locations, re-exported from the project path configuration so the
# rest of this module (and its importers) can refer to them directly.
SMPL_MEAN_PARAMS, SMPL_MODEL_DIR = (
    path_config.SMPL_MEAN_PARAMS,
    path_config.SMPL_MODEL_DIR,
)
|
|
|
|
@dataclass
class ModelOutput(SMPLXOutput):
    """Output container returned by the body-model wrappers in this module.

    Extends ``SMPLXOutput`` with extra vertex/joint groups. Each wrapper fills
    only the fields it computes; every other field keeps its ``None`` default.
    """
    # First 24 joints of the underlying model output (SMPL.forward, SMPLX_ALL.forward).
    smpl_joints: Optional[torch.Tensor] = None
    # Subset of the last 24 mapped joints selected by constants.J24_TO_J19.
    joints_J19: Optional[torch.Tensor] = None
    # Raw SMPL-X vertices (SMPLX_ALL.forward, where `vertices` holds the SMPL mesh).
    smplx_vertices: Optional[torch.Tensor] = None
    # FLAME mesh vertices (FLAME.forward).
    flame_vertices: Optional[torch.Tensor] = None
    # Hand vertices: SMPL hand sub-mesh in SMPLX_ALL.forward; full MANO mesh for
    # rhand_vertices in MANO.forward.
    lhand_vertices: Optional[torch.Tensor] = None
    rhand_vertices: Optional[torch.Tensor] = None
    # Hand joints (SMPLX_ALL.forward; MANO.forward sets rhand_joints).
    lhand_joints: Optional[torch.Tensor] = None
    rhand_joints: Optional[torch.Tensor] = None
    # Face landmark joints (SMPLX_ALL.forward, FLAME.forward).
    face_joints: Optional[torch.Tensor] = None
    # Foot joints (SMPLX_ALL.forward).
    lfoot_joints: Optional[torch.Tensor] = None
    rfoot_joints: Optional[torch.Tensor] = None
|
|
|
|
class SMPL(_SMPL):
    """ Extension of the official SMPL implementation to support more joints """

    def __init__(
        self,
        create_betas=False,
        create_global_orient=False,
        create_body_pose=False,
        create_transl=False,
        *args,
        **kwargs
    ):
        """Build the base SMPL model, then attach the extra-joint regressor and T-pose joints.

        The ``create_*`` flags are forwarded to the base class (defaulting to
        ``False`` here); all other arguments are passed through unchanged.
        """
        super().__init__(
            create_betas=create_betas,
            create_global_orient=create_global_orient,
            create_body_pose=create_body_pose,
            create_transl=create_transl,
            *args,
            **kwargs
        )
        # Indices that reorder [model joints | extra joints] into constants.JOINT_NAMES order.
        joint_order = [constants.JOINT_MAP[name] for name in constants.JOINT_NAMES]
        extra_regressor = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(extra_regressor, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joint_order, dtype=torch.long)

        # Joints regressed from the unposed template mesh; shape 1 x K x 3.
        self.register_buffer(
            'tpose_joints',
            vertices2joints(self.J_regressor, self.v_template.unsqueeze(0)),
        )

    def forward(self, *args, **kwargs):
        """Run SMPL (always skinned) and return a ModelOutput with extra joints.

        ``joints`` is the concatenation of the model joints and the joints
        regressed by ``J_regressor_extra``, reordered by ``joint_map``;
        ``joints_J19`` picks ``constants.J24_TO_J19`` out of its last 24 entries;
        ``smpl_joints`` holds the first 24 model joints.
        """
        kwargs['get_skin'] = True
        base_out = super().forward(*args, **kwargs)
        verts = base_out.vertices

        regressed = vertices2joints(self.J_regressor_extra, verts)
        mapped = torch.cat([base_out.joints, regressed], dim=1)[:, self.joint_map, :]
        last24 = mapped[:, -24:, :]

        return ModelOutput(
            vertices=verts,
            global_orient=base_out.global_orient,
            body_pose=base_out.body_pose,
            joints=mapped,
            joints_J19=last24[:, constants.J24_TO_J19, :],
            smpl_joints=base_out.joints[:, :24],
            betas=base_out.betas,
            full_pose=base_out.full_pose
        )

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''
        Pose the SMPL T-pose skeleton and return per-joint global transforms.

        Joints come from the ``tpose_joints`` buffer, so shape (betas) is not
        applied. Omitted rotations default to identity.

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3 (or Bx1x3x3)
            Root rotation in rotation-matrix format. (default=identity)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body-joint rotations (J = NUM_BODY_JOINTS) in rotation-matrix
            format. (default=identity)
        kwargs:
            Accepted and ignored.

        Returns
        -------
        global_rotmat: torch.Tensor, shape BxKx3x3
            Accumulated rotation of every joint of the kinematic tree
            (K = len(self.parents)).
        posed_joints: torch.Tensor, shape BxKx3
            Translation column of the accumulated transforms (posed joints).
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Batch size = largest leading dim among the supplied poses, at least 1.
        batch_size = max(
            [1] + [len(pose) for pose in (global_orient, body_pose) if pose is not None]
        )

        def identity_rotations(num_joints):
            # B x num_joints x 3 x 3 stack of identity matrices.
            return torch.eye(3, device=device, dtype=dtype).expand(
                batch_size, num_joints, 3, 3
            ).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)

        rot_mats = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            ],
            dim=1
        ).view(batch_size, -1, 3, 3)

        # Rest joints as B x K x 3 x 1 column vectors.
        rest_joints = self.tpose_joints.expand(batch_size, -1, -1).unsqueeze(-1)
        # Offsets relative to each joint's parent; the root keeps its absolute position.
        offsets = rest_joints.clone()
        offsets[:, 1:] -= rest_joints[:, self.parents[1:]]

        num_joints = rest_joints.shape[1]
        local_tf = transform_mat(
            rot_mats.reshape(-1, 3, 3), offsets.reshape(-1, 3, 1)
        ).reshape(-1, num_joints, 4, 4)

        # Compose local transforms down the tree; assumes parents[i] < i for i >= 1.
        chain = [local_tf[:, 0]]
        for child in range(1, self.parents.shape[0]):
            chain.append(torch.matmul(chain[self.parents[child]], local_tf[:, child]))
        world_tf = torch.stack(chain, dim=1)

        return world_tf[:, :, :3, :3], world_tf[:, :, :3, 3]
|
|
|
|
class SMPLX(SMPLXLayer):
    """ Extension of the official SMPLX implementation to support more functions """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        left_hand_pose: Optional[torch.Tensor] = None,
        right_hand_pose: Optional[torch.Tensor] = None,
        jaw_pose: Optional[torch.Tensor] = None,
        leye_pose: Optional[torch.Tensor] = None,
        reye_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''
        Pose the SMPL-X template skeleton and return per-joint global transforms.

        Joints are regressed from ``v_template`` directly, so shape (betas) and
        expression are not applied. All poses are rotation matrices; any pose
        that is omitted defaults to identity.

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3 (or Bx1x3x3)
            Root rotation. (default=identity)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body-joint rotations, J = NUM_BODY_JOINTS. (default=identity)
        left_hand_pose: torch.tensor, optional, shape BxHx3x3
            Left-hand joint rotations, H = NUM_HAND_JOINTS. (default=identity)
        right_hand_pose: torch.tensor, optional, shape BxHx3x3
            Right-hand joint rotations, H = NUM_HAND_JOINTS. (default=identity)
        jaw_pose: torch.tensor, optional, shape Bx3x3
            Jaw rotation. (default=identity)
        leye_pose: torch.tensor, optional, shape Bx3x3
            Left-eye rotation. (default=identity)
        reye_pose: torch.tensor, optional, shape Bx3x3
            Right-eye rotation. (default=identity)
        kwargs:
            Accepted and ignored.

        Returns
        -------
        global_rotmat: torch.Tensor, shape BxKx3x3
            Accumulated rotation of every joint in the kinematic tree
            (K = len(self.parents)).
        posed_joints: torch.Tensor, shape BxKx3
            Translation column of the accumulated transforms (posed joints).
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every supplied pose. The eye poses are included
        # so that batched eye poses combined with defaulted poses no longer mismatch
        # in torch.cat below.
        model_vars = [
            global_orient, body_pose, left_hand_pose, right_hand_pose, jaw_pose,
            leye_pose, reye_pose
        ]
        batch_size = max([1] + [len(var) for var in model_vars if var is not None])

        def identity(num_joints):
            # B x num_joints x 3 x 3 stack of identity rotations.
            return torch.eye(3, device=device, dtype=dtype).expand(
                batch_size, num_joints, 3, 3
            ).contiguous()

        if global_orient is None:
            global_orient = identity(1)
        if body_pose is None:
            body_pose = identity(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity(self.NUM_HAND_JOINTS)
        if jaw_pose is None:
            jaw_pose = identity(1)
        if leye_pose is None:
            leye_pose = identity(1)
        if reye_pose is None:
            reye_pose = identity(1)

        # Joint order: root, body, jaw, left eye, right eye, left hand, right hand.
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
            ],
            dim=1
        )
        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # Template-space joints, B x K x 3 x 1.
        joints = vertices2joints(
            self.J_regressor,
            self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        )
        joints = torch.unsqueeze(joints, dim=-1)

        # Offsets relative to each joint's parent; the root keeps its absolute position.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        transforms_mat = transform_mat(
            rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
        ).reshape(-1, joints.shape[1], 4, 4)

        # Compose local transforms down the tree; assumes parents[i] < i for i >= 1.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, self.parents.shape[0]):
            transform_chain.append(
                torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
            )
        transforms = torch.stack(transform_chain, dim=1)

        global_rotmat = transforms[:, :, :3, :3]
        posed_joints = transforms[:, :, :3, 3]
        return global_rotmat, posed_joints
|
|
|
|
class SMPLX_ALL(nn.Module):
    """ Extension of the official SMPLX implementation to support more joints """

    def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs):
        """Load SMPL-X model(s) plus the assets needed to derive SMPL outputs.

        Parameters
        ----------
        batch_size: int
            Forwarded to every SMPLX model.
        use_face_contour: bool
            Forwarded to SMPLX; also selects whether ``forward`` returns the
            last 68 (True) or last 51 (False) SMPL-X joints as ``face_joints``.
        all_gender: bool
            Load male, female and neutral models if True, else only neutral.
        **kwargs:
            Extra SMPLX constructor arguments.
        """
        super().__init__()
        numBetas = 10
        self.use_face_contour = use_face_contour
        if all_gender:
            self.genders = ['male', 'female', 'neutral']
        else:
            self.genders = ['neutral']
        for gender in self.genders:
            assert gender in ['male', 'female', 'neutral']
        self.model_dict = nn.ModuleDict(
            {
                gender: SMPLX(
                    path_config.SMPL_MODEL_DIR,
                    gender=gender,
                    ext='npz',
                    num_betas=numBetas,
                    use_pca=False,
                    batch_size=batch_size,
                    use_face_contour=use_face_contour,
                    num_pca_comps=45,
                    **kwargs
                )
                for gender in self.genders
            }
        )
        self.model_neutral = self.model_dict['neutral']
        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)

        # SMPL-X -> SMPL vertex transfer matrix. Load it inside `with` so the
        # file handle is closed (the bare open() previously leaked it).
        transfer_path = os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl')
        with open(transfer_path, 'rb') as transfer_file:
            smplx_to_smpl = pickle.load(transfer_file)
        self.register_buffer(
            'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)
        )

        # SMPL vertex ids of the left/right hand sub-meshes.
        smpl2limb_vert_faces = get_partial_smpl('smpl')
        self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
        self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()

        # SMPL-X joint ids of the hand joints, in constants.HAND_NAMES order.
        smplx2lhand_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        smplx2rhand_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long)
        self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long)

        # SMPL-X joint ids of the foot joints, in constants.FOOT_NAMES order.
        smplx2lfoot_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        smplx2rfoot_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long)
        self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long)

        # Per gender: joints 0-23 regressed from the template and their shape
        # blend directions; get_tpose combines them with betas.
        for g in self.genders:
            J_template = torch.einsum(
                'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]
            )
            J_dirs = torch.einsum(
                'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]
            )
            self.register_buffer(f'{g}_J_template', J_template)
            self.register_buffer(f'{g}_J_dirs', J_dirs)

    @staticmethod
    def _restore_batch_order(gender_idx_list, batch_size, device):
        """Return indices that undo the per-gender grouping of the batch.

        ``gender_idx_list[k]`` is the original batch index of the k-th row of the
        concatenated per-gender outputs. Every sample must appear exactly once,
        otherwise a ValueError is raised.
        """
        if sorted(gender_idx_list) != list(range(batch_size)):
            raise ValueError(
                'each sample needs a gender code of 0 (male), 1 (female) or 2 (neutral)'
            )
        # argsort of a permutation is its inverse permutation.
        return torch.argsort(torch.tensor(gender_idx_list, dtype=torch.long, device=device))

    def forward(self, *args, **kwargs):
        """Run the per-gender SMPL-X models and assemble SMPL-compatible outputs.

        ``body_pose`` is required; its first dim is the batch size. Poses are
        axis-angle and converted with ``batch_rodrigues`` unless
        ``pose2rot=False``. ``gender`` is a per-sample code tensor
        (0 male, 1 female, 2 neutral; default neutral). Only ``betas`` and the
        pose keys are forwarded to the SMPL-X models; any other kwargs
        (e.g. expression, transl) are not passed on.

        Returns a ModelOutput whose ``vertices`` are SMPL vertices obtained via
        ``smplx2smpl``, plus SMPL-X vertices and hand/foot/face joint groups.

        Raises ValueError if some sample has no valid gender code.
        """
        batch_size = kwargs['body_pose'].shape[0]
        device = kwargs['body_pose'].device
        kwargs['get_skin'] = True
        if 'pose2rot' not in kwargs:
            kwargs['pose2rot'] = True
        if 'gender' not in kwargs:
            kwargs['gender'] = 2 * torch.ones(batch_size).to(device)

        pose_keys = [
            'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose',
            'leye_pose', 'reye_pose'
        ]
        param_keys = ['betas'] + pose_keys
        if kwargs['pose2rot']:
            for key in pose_keys:
                if key in kwargs:
                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
                        [batch_size, -1, 3, 3]
                    )
        if kwargs['body_pose'].shape[1] == 23:
            # SMPL-style 23-joint body pose: keep only the first 21 joints for SMPL-X.
            kwargs['body_pose'] = kwargs['body_pose'][:, :21]

        # Run each gender's model on its subset of the batch.
        gender_idx_list = []
        smplx_vertices, smplx_joints = [], []
        for gi, g in enumerate(['male', 'female', 'neutral']):
            gender_idx = (kwargs['gender'] == gi).nonzero(as_tuple=True)[0]
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend(int(idx) for idx in gender_idx)
            gender_kwargs = {'get_skin': kwargs['get_skin'], 'pose2rot': kwargs['pose2rot']}
            gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs})
            gender_smplx_output = self.model_dict[g].forward(*args, **gender_kwargs)
            smplx_vertices.append(gender_smplx_output.vertices)
            smplx_joints.append(gender_smplx_output.joints)

        idx_rearrange = self._restore_batch_order(gender_idx_list, batch_size, device)
        smplx_vertices = torch.cat(smplx_vertices)[idx_rearrange]
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]

        lhand_joints = smplx_joints[:, self.smplx2lh_joint_map]
        rhand_joints = smplx_joints[:, self.smplx2rh_joint_map]
        face_joints = smplx_joints[:, -68:] if self.use_face_contour else smplx_joints[:, -51:]
        lfoot_joints = smplx_joints[:, self.smplx2lf_joint_map]
        rfoot_joints = smplx_joints[:, self.smplx2rf_joint_map]

        # Convert to the SMPL mesh and derive SMPL-style joints from it.
        smpl_vertices = torch.bmm(self.smplx2smpl.expand(batch_size, -1, -1), smplx_vertices)
        lhand_vertices = smpl_vertices[:, self.smpl2lhand]
        rhand_vertices = smpl_vertices[:, self.smpl2rhand]
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_vertices)
        smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45]
        joints = torch.cat([smplx_j45, extra_joints], dim=1)
        smpl_joints = smplx_j45[:, :24]
        joints = joints[:, self.joint_map, :]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
        output = ModelOutput(
            vertices=smpl_vertices,
            smplx_vertices=smplx_vertices,
            lhand_vertices=lhand_vertices,
            rhand_vertices=rhand_vertices,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            lhand_joints=lhand_joints,
            rhand_joints=rhand_joints,
            lfoot_joints=lfoot_joints,
            rfoot_joints=rfoot_joints,
            face_joints=face_joints,
        )
        return output

    def get_tpose(self, betas=None, gender=None):
        """Return joints 0-23 of the shaped, unposed template for each sample.

        Parameters
        ----------
        betas: torch.Tensor, optional
            Shape coefficients; defaults to zeros of shape (1, 10).
        gender: torch.Tensor, optional
            Per-sample gender codes (0 male, 1 female, 2 neutral); defaults to neutral.

        Returns
        -------
        torch.Tensor
            ``{g}_J_template + blend_shapes(betas, {g}_J_dirs)`` per sample,
            in the original batch order.

        Raises
        ------
        ValueError
            If some sample has no valid gender code.
        """
        kwargs = {}
        if betas is None:
            betas = torch.zeros(1, 10).to(self.J_regressor_extra.device)
        kwargs['betas'] = betas

        batch_size = kwargs['betas'].shape[0]
        device = kwargs['betas'].device

        if gender is None:
            kwargs['gender'] = 2 * torch.ones(batch_size).to(device)
        else:
            kwargs['gender'] = gender

        param_keys = ['betas']

        gender_idx_list = []
        smplx_joints = []
        for gi, g in enumerate(['male', 'female', 'neutral']):
            gender_idx = (kwargs['gender'] == gi).nonzero(as_tuple=True)[0]
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend(int(idx) for idx in gender_idx)
            gender_kwargs = {k: kwargs[k][gender_idx] for k in param_keys if k in kwargs}
            J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(
                gender_kwargs['betas'], getattr(self, f'{g}_J_dirs')
            )
            smplx_joints.append(J)

        idx_rearrange = self._restore_batch_order(gender_idx_list, batch_size, device)
        return torch.cat(smplx_joints)[idx_rearrange]
|
|
|
|
class MANO(MANOLayer):
    """ Extension of the official MANO implementation to support more joints """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run MANO and return hand vertices plus 21 joints including fingertips.

        ``global_orient`` is required (the batch size is read from it) and
        ``right_hand_pose`` is required (it is renamed to ``hand_pose`` for the
        base layer). Poses are axis-angle unless ``pose2rot=False``.
        """
        kwargs.setdefault('pose2rot', True)
        num = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for name in ('global_orient', 'right_hand_pose'):
                if name not in kwargs:
                    continue
                axis_angle = kwargs[name].contiguous().view(-1, 3)
                kwargs[name] = batch_rodrigues(axis_angle).view([num, -1, 3, 3])
        kwargs['hand_pose'] = kwargs.pop('right_hand_pose')

        hand_out = super().forward(*args, **kwargs)
        verts = hand_out.vertices

        # Vertex ids used as fingertip joints, appended after the model joints.
        fingertip_vids = [745, 317, 445, 556, 673]
        # Reordering of [model joints | fingertips] into the 21-joint output layout.
        joint_order = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
        hand_joints = torch.cat([hand_out.joints, verts[:, fingertip_vids]], 1)[:, joint_order]

        return ModelOutput(rhand_vertices=verts, rhand_joints=hand_joints)
|
|
|
|
class FLAME(FLAMELayer):
    """ Extension of the official FLAME implementation to support more joints """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run FLAME and return its vertices plus face landmark joints.

        ``global_orient`` is required (the batch size is read from it). Poses
        are axis-angle unless ``pose2rot=False``. ``face_joints`` are all
        model joints except the first five.
        """
        kwargs.setdefault('pose2rot', True)
        num = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for name in ('global_orient', 'jaw_pose', 'leye_pose', 'reye_pose'):
                if name not in kwargs:
                    continue
                axis_angle = kwargs[name].contiguous().view(-1, 3)
                kwargs[name] = batch_rodrigues(axis_angle).view([num, -1, 3, 3])

        head_out = super().forward(*args, **kwargs)
        return ModelOutput(
            flame_vertices=head_out.vertices,
            face_joints=head_out.joints[:, 5:],
        )
|
|
|
|
class SMPL_Family():
    """Select one of this module's body models by name and forward calls to it."""

    def __init__(self, model_type='smpl', *args, **kwargs):
        """
        Parameters
        ----------
        model_type: str
            One of 'smpl', 'smplx', 'mano' (right hand) or 'flame'.
        *args, **kwargs:
            Forwarded to the selected model's constructor.

        Raises
        ------
        ValueError
            If ``model_type`` is not supported. Without this check, ``self.model``
            stayed unset and the error only surfaced later as an AttributeError.
        """
        if model_type == 'smpl':
            self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs)
        elif model_type == 'smplx':
            self.model = SMPLX_ALL(*args, **kwargs)
        elif model_type == 'mano':
            self.model = MANO(
                model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs
            )
        elif model_type == 'flame':
            self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs)
        else:
            raise ValueError(
                f"Unsupported model_type '{model_type}'; "
                "expected 'smpl', 'smplx', 'mano' or 'flame'"
            )

    def __call__(self, *args, **kwargs):
        """Run the wrapped model's forward pass."""
        return self.model(*args, **kwargs)

    def get_tpose(self, *args, **kwargs):
        """Delegate to the wrapped model's ``get_tpose`` (only SMPLX_ALL defines one here)."""
        return self.model.get_tpose(*args, **kwargs)
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
|
|
def get_smpl_faces():
    """Return the face array of an SMPL model loaded from SMPL_MODEL_DIR."""
    return SMPL(model_path=SMPL_MODEL_DIR, batch_size=1).faces
|
|
|
|
def get_smplx_faces():
    """Return the face array of an SMPL-X model loaded from SMPL_MODEL_DIR."""
    return SMPLX(SMPL_MODEL_DIR, batch_size=1).faces
|
|
|
|
def get_mano_faces(hand_type='right'):
    """Return the face array of the MANO model for ``hand_type`` ('right' or 'left')."""
    assert hand_type in ['right', 'left']
    hand_model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=(hand_type == 'right'))
    return hand_model.faces
|
|
|
|
def get_flame_faces():
    """Return the face array of a FLAME model loaded from SMPL_MODEL_DIR."""
    return FLAME(SMPL_MODEL_DIR, batch_size=1).faces
|
|
|
|
def get_model_faces(type='smpl'):
    """Return the mesh faces of the named body model.

    Parameters
    ----------
    type: str
        One of 'smpl', 'smplx', 'mano' (right hand) or 'flame'. The name
        shadows the builtin but is kept for keyword callers.

    Raises
    ------
    ValueError
        If ``type`` is not a supported model name (instead of silently
        returning None).
    """
    if type == 'smpl':
        return get_smpl_faces()
    if type == 'smplx':
        return get_smplx_faces()
    if type == 'mano':
        return get_mano_faces()
    if type == 'flame':
        return get_flame_faces()
    raise ValueError(
        "Unsupported model type '{}'; expected 'smpl', 'smplx', 'mano' or 'flame'".format(type)
    )
|
|
|
|
def get_model_tpose(type='smpl'):
    """Return the template-pose vertices of the named body model.

    Parameters
    ----------
    type: str
        One of 'smpl', 'smplx', 'mano' (right hand) or 'flame'. The name
        shadows the builtin but is kept for keyword callers.

    Raises
    ------
    ValueError
        If ``type`` is not a supported model name (instead of silently
        returning None).
    """
    if type == 'smpl':
        return get_smpl_tpose()
    if type == 'smplx':
        return get_smplx_tpose()
    if type == 'mano':
        return get_mano_tpose()
    if type == 'flame':
        return get_flame_tpose()
    raise ValueError(
        "Unsupported model type '{}'; expected 'smpl', 'smplx', 'mano' or 'flame'".format(type)
    )
|
|
|
|
def get_smpl_tpose():
    """Return SMPL vertices (batch element 0, detached) from a forward pass with the
    model's own created betas/orient/pose parameters (presumably zero-initialised,
    i.e. the T-pose — confirm against the SMPL base class)."""
    model = SMPL(
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
    )
    first_vertices = model().vertices[0]
    return first_vertices.detach()
|
|
|
|
def get_smpl_tpose_joint():
    """Return the 24 ``smpl_joints`` (batch element 0, detached) from a forward pass
    with the model's own created betas/orient/pose parameters."""
    model = SMPL(
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
    )
    first_joints = model().smpl_joints[0]
    return first_joints.detach()
|
|
|
|
def get_smplx_tpose():
    """Return SMPL-X vertices (batch element 0) from a default SMPLXLayer forward pass."""
    return SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)().vertices[0]
|
|
|
|
def get_smplx_tpose_joint():
    """Return SMPL-X joints (batch element 0) from a default SMPLXLayer forward pass."""
    return SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)().joints[0]
|
|
|
|
def get_mano_tpose():
    """Return right-hand MANO vertices for zero global orientation and zero hand pose."""
    hand_model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True)
    zero_orient = torch.zeros(1, 3)
    zero_pose = torch.zeros(1, 45)  # 15 joints x 3 axis-angle values
    rest = hand_model(global_orient=zero_orient, right_hand_pose=zero_pose)
    return rest.rhand_vertices[0]
|
|
|
|
def get_flame_tpose():
    """Return FLAME vertices for zero global orientation (other inputs at model defaults)."""
    head_model = FLAME(SMPL_MODEL_DIR, batch_size=1)
    rest = head_model(global_orient=torch.zeros(1, 3))
    return rest.flame_vertices[0]
|
|
|
|
def get_part_joints(smpl_joints):
    """Build 23 part-level joints from SMPL joints.

    The first 18 outputs are midpoints of fixed joint pairs; the last 5 copy
    joints 10, 11, 15, 22 and 23 unchanged.

    Parameters
    ----------
    smpl_joints: torch.Tensor
        Indexed along dim 1 by SMPL joint id (presumably B x 24 x 3).

    Returns
    -------
    torch.Tensor
        Same layout with 23 entries along dim 1.
    """
    pair_list = [
        (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17),
        (1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21),
    ]
    single_list = [10, 11, 15, 22, 23]
    device = smpl_joints.device

    pair_idx = torch.tensor(pair_list, dtype=torch.long, device=device)
    single_idx = torch.tensor(single_list, dtype=torch.long, device=device)

    # smpl_joints[:, pair_idx] stacks each pair along a new dim 2; average over it.
    midpoints = smpl_joints[:, pair_idx].mean(dim=2)
    singles = smpl_joints[:, single_idx]
    return torch.cat([midpoints, singles], dim=1)
|
|
|
|
# Segmentation labels (keys of '<body_model>_vert_segmentation.json') whose vertices
# seed each segmentation-based part. 'arm_eval' is not in get_partial_smpl's part
# list, so it only takes effect if that list is extended.
_SEGMENTATION_PARTS = {
    'face': ['head'],
    'arm': [
        'rightHand',
        'leftArm',
        'leftShoulder',
        'rightShoulder',
        'rightArm',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'forearm': [
        'rightHand',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'arm_eval': ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm'],
    'larm': ['leftForeArm'],
    'rarm': ['rightForeArm'],
}


def _extract_submesh(faces, num_verts, seed_vids):
    """Cut out every face touching ``seed_vids`` and re-index it as a standalone mesh.

    Parameters
    ----------
    faces: np.ndarray
        Vertex-index faces of the full mesh (presumably F x 3 triangles).
    num_verts: int
        Number of vertices of the full mesh.
    seed_vids: sequence of int
        Vertex ids; every face containing at least one of them is kept.

    Returns
    -------
    (part_vids, part_faces): sorted full-mesh ids of all vertices used by the
    kept faces, and the kept faces re-indexed into ``part_vids`` (both np.longlong).
    """
    # Vectorised membership test; the previous per-face `f in list` scan was O(F * N).
    face_ids = np.nonzero(np.isin(faces, seed_vids).any(axis=1))[0]
    kept_faces = faces[face_ids]
    part_vids = np.unique(kept_faces).astype(np.longlong)
    # Lookup table full-mesh id -> submesh id (only entries listed in part_vids are read).
    remap = np.arange(num_verts)
    remap[part_vids] = np.arange(len(part_vids))
    part_faces = remap[kept_faces].astype(np.longlong)
    return part_vids, part_faces


def _smpl_hand_vert_faces(part):
    """Map each MANO vertex of hand ``part`` ('lhand'/'rhand') to its nearest SMPL vertex.

    Returns (SMPL vertex ids as np.longlong, MANO faces as np.longlong).
    """
    with open(
        os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb'
    ) as pkl_file:
        smplx_mano_id = pickle.load(pkl_file)
    with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') as pkl_file:
        smplx_smpl_id = pickle.load(pkl_file)

    smplx_tpose = get_smplx_tpose()
    # Transfer the SMPL-X template to SMPL in NumPy explicitly, then back to torch,
    # instead of relying on np.matmul implicitly re-wrapping a torch.Tensor operand.
    smpl_tpose = torch.from_numpy(
        np.matmul(smplx_smpl_id['matrix'], smplx_tpose.detach().cpu().numpy())
    )

    hand_key = 'left_hand' if part == 'lhand' else 'right_hand'
    mano_vert = smplx_tpose[smplx_mano_id[hand_key]]

    smpl2mano_id = []
    for vert in mano_vert:
        v_diff = smpl_tpose - vert
        v_diff = torch.sum(v_diff * v_diff, dim=1)
        smpl2mano_id.append(int(torch.argmin(v_diff)))

    smpl2mano_vids = np.array(smpl2mano_id).astype(np.longlong)
    mano_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left'
                                ).astype(np.longlong)
    return smpl2mano_vids, mano_faces


def _wrist_seed_vids(body_model, part):
    """Return ids of template vertices whose squared distance to the wrist joint is < 0.005.

    Joint 20 is used for 'lwrist' and joint 21 for 'rwrist'.

    Raises
    ------
    ValueError
        If ``body_model`` is neither 'smpl' nor 'smplx'.
    """
    if body_model == 'smplx':
        body_model_verts = get_smplx_tpose()
        tpose_joint = get_smplx_tpose_joint()
    elif body_model == 'smpl':
        body_model_verts = get_smpl_tpose()
        tpose_joint = get_smpl_tpose_joint()
    else:
        raise ValueError(
            "wrist parts need body_model 'smpl' or 'smplx', got '{}'".format(body_model)
        )

    wrist_joint = tpose_joint[20] if part == 'lwrist' else tpose_joint[21]
    dist = 0.005  # threshold on the *squared* distance, in model units
    wrist_vids = []
    for vid, vt in enumerate(body_model_verts):
        if torch.sum((vt - wrist_joint)**2) < dist:
            wrist_vids.append(vid)
    return np.array(wrist_vids)


def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
    """Load, or build and cache, vertex ids and faces of body-part sub-meshes.

    For each part in ('lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm',
    'lwrist', 'rwrist') the result is read from
    '<PARTIAL_MESH_DIR>/<body_model>_<part>_vids.npz' when that file exists;
    otherwise it is computed and saved there.

    Parameters
    ----------
    body_model: str
        Model whose faces/vertex count are used ('smpl' or 'smplx' for the wrist
        parts). NOTE: the hand parts are always computed against SMPL vertices,
        whatever ``body_model`` is.
    device: torch.device
        Unused; kept for interface compatibility.

    Returns
    -------
    dict
        part name -> {'vids': vertex ids, 'faces': re-indexed faces}.
    """
    body_model_faces = get_model_faces(body_model)
    body_model_num_verts = len(get_model_tpose(body_model))

    part_vert_faces = {}
    for part in ['lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm', 'lwrist', 'rwrist']:
        part_vid_fname = '{}/{}_{}_vids.npz'.format(path_config.PARTIAL_MESH_DIR, body_model, part)
        if os.path.exists(part_vid_fname):
            # Use the NpzFile as a context manager so the underlying file is closed.
            with np.load(part_vid_fname) as part_vids:
                part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']}
            continue

        if part in ('lhand', 'rhand'):
            vids, faces = _smpl_hand_vert_faces(part)
        elif part in _SEGMENTATION_PARTS:
            with open(
                os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)),
                'rb'
            ) as json_file:
                segmentation = json.load(json_file)
            seed_vids = [
                vid for label in _SEGMENTATION_PARTS[part] for vid in segmentation[label]
            ]
            vids, faces = _extract_submesh(body_model_faces, body_model_num_verts, seed_vids)
        else:  # 'lwrist' / 'rwrist' — the only remaining entries of the part list
            seed_vids = _wrist_seed_vids(body_model, part)
            vids, faces = _extract_submesh(body_model_faces, body_model_num_verts, seed_vids)

        np.savez(part_vid_fname, vids=vids, faces=faces)
        part_vert_faces[part] = {'vids': vids, 'faces': faces}

    return part_vert_faces
|
|