| | |
| |
|
| | from typing import Optional |
| | from dataclasses import dataclass |
| |
|
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import pickle |
| | from lib.smplx import SMPL as _SMPL |
| | from lib.smplx import SMPLXLayer, MANOLayer, FLAMELayer |
| | from lib.smplx.lbs import batch_rodrigues, transform_mat, vertices2joints, blend_shapes |
| | from lib.smplx.body_models import SMPLXOutput |
| | import json |
| |
|
| | from lib.pymafx.core import path_config, constants |
| |
|
# Asset locations resolved once from the project's path configuration.
# SMPL_MEAN_PARAMS is not referenced again in the visible part of this file;
# presumably it is re-exported for other modules - verify before removing.
SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS
# Directory handed to the body-model constructors below as their model path.
SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR
| |
|
| |
|
@dataclass
class ModelOutput(SMPLXOutput):
    """Output record shared by the body, hand and face wrappers in this file.

    Extends ``SMPLXOutput`` with optional tensors. Every field defaults to
    ``None``; each wrapper fills only the fields it produces. The field order
    is part of the dataclass interface and must not be changed.
    """
    # First 24 joints of the underlying body model output.
    smpl_joints: Optional[torch.Tensor] = None
    # Last 24 remapped joints, re-indexed through ``constants.J24_TO_J19``.
    joints_J19: Optional[torch.Tensor] = None
    # Vertices of the SMPL-X mesh (set by the SMPL-X wrapper).
    smplx_vertices: Optional[torch.Tensor] = None
    # Vertices of the FLAME mesh (set by the FLAME wrapper).
    flame_vertices: Optional[torch.Tensor] = None
    # Hand vertices: a subset of body vertices or the full MANO mesh,
    # depending on which wrapper produced the output.
    lhand_vertices: Optional[torch.Tensor] = None
    rhand_vertices: Optional[torch.Tensor] = None
    # Per-part joint subsets.
    lhand_joints: Optional[torch.Tensor] = None
    rhand_joints: Optional[torch.Tensor] = None
    face_joints: Optional[torch.Tensor] = None
    lfoot_joints: Optional[torch.Tensor] = None
    rfoot_joints: Optional[torch.Tensor] = None
| |
|
| |
|
class SMPL(_SMPL):
    """SMPL body model extended with extra regressed joints and a rest-pose cache."""

    def __init__(
        self,
        create_betas=False,
        create_global_orient=False,
        create_body_pose=False,
        create_transl=False,
        *args,
        **kwargs
    ):
        """Build the base model, then attach the extra joint regressor.

        The four ``create_*`` flags are forwarded to the base class; unlike
        passing nothing, they default to ``False`` here. Remaining positional
        and keyword arguments are passed through unchanged.
        """
        super().__init__(
            *args,
            create_betas=create_betas,
            create_global_orient=create_global_orient,
            create_body_pose=create_body_pose,
            create_transl=create_transl,
            **kwargs
        )
        joint_ids = [constants.JOINT_MAP[name] for name in constants.JOINT_NAMES]
        extra_regressor = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(extra_regressor, dtype=torch.float32)
        )
        # Plain attribute (not a buffer): it is not part of the state dict.
        self.joint_map = torch.tensor(joint_ids, dtype=torch.long)

        # Joints of the unposed template, cached for get_global_rotation().
        self.register_buffer(
            'tpose_joints',
            vertices2joints(self.J_regressor, self.v_template.unsqueeze(0)),
        )

    def forward(self, *args, **kwargs):
        """Run the base model and return a ``ModelOutput`` with remapped joints.

        ``get_skin`` is always forced to ``True``. The base joints are
        concatenated with joints regressed from the vertices by
        ``J_regressor_extra`` and then re-indexed through ``joint_map``.
        """
        kwargs['get_skin'] = True
        base = super().forward(*args, **kwargs)

        regressed = vertices2joints(self.J_regressor_extra, base.vertices)
        remapped = torch.cat([base.joints, regressed], dim=1)[:, self.joint_map, :]
        last_24 = remapped[:, -24:, :]

        return ModelOutput(
            vertices=base.vertices,
            global_orient=base.global_orient,
            body_pose=base.body_pose,
            joints=remapped,
            joints_J19=last_24[:, constants.J24_TO_J19, :],
            smpl_joints=base.joints[:, :24],
            betas=base.betas,
            full_pose=base.full_pose,
        )

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''
            Accumulate joint rotations along the kinematic tree of the template

            Parameters
            ----------
            global_orient: torch.tensor, optional
                Root rotation in rotation matrix format; reshaped to Bx1x3x3.
                Identity is used when it is not given. (default=None)
            body_pose: torch.tensor, optional
                Body joint rotations in rotation matrix format; reshaped to
                BxJx3x3 with J = NUM_BODY_JOINTS. Identity is used when it
                is not given. (default=None)
            kwargs:
                Accepted for call compatibility and ignored.
            Returns
            -------
                global_rotmat: the 3x3 block of every accumulated transform
                posed_joints: the translation column of every accumulated
                    transform
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # The batch size is the largest leading dimension among given inputs.
        given = [p for p in (global_orient, body_pose) if p is not None]
        batch_size = max([1] + [len(p) for p in given])

        identity = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
        if global_orient is None:
            global_orient = identity.expand(batch_size, -1, -1, -1).contiguous()
        if body_pose is None:
            body_pose = identity.expand(
                batch_size, self.NUM_BODY_JOINTS, -1, -1
            ).contiguous()

        rot_mats = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            ],
            dim=1
        ).view(batch_size, -1, 3, 3)

        joints = self.tpose_joints.expand(batch_size, -1, -1).unsqueeze(-1)

        # Express every non-root joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        local_transforms = transform_mat(
            rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
        ).reshape(-1, joints.shape[1], 4, 4)

        # Parents are visited before children, so the parent's accumulated
        # transform is already in the chain when a child is processed.
        chain = [local_transforms[:, 0]]
        for joint_idx in range(1, self.parents.shape[0]):
            parent_transform = chain[self.parents[joint_idx]]
            chain.append(torch.matmul(parent_transform, local_transforms[:, joint_idx]))

        transforms = torch.stack(chain, dim=1)
        return transforms[:, :, :3, :3], transforms[:, :, :3, 3]
| |
|
| |
|
class SMPLX(SMPLXLayer):
    """ Extension of the official SMPLX implementation to support more functions """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        left_hand_pose: Optional[torch.Tensor] = None,
        right_hand_pose: Optional[torch.Tensor] = None,
        jaw_pose: Optional[torch.Tensor] = None,
        leye_pose: Optional[torch.Tensor] = None,
        reye_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''
            Accumulate joint rotations along the kinematic tree of the template

            Every pose argument is expected in rotation matrix format and is
            replaced by identity rotations when it is not given.

            Parameters
            ----------
            global_orient: torch.tensor, optional
                Root rotation; reshaped to Bx1x3x3. (default=None)
            body_pose: torch.tensor, optional
                Body joint rotations; reshaped to BxJx3x3 with
                J = NUM_BODY_JOINTS. (default=None)
            left_hand_pose: torch.tensor, optional
                Left hand rotations; reshaped to BxHx3x3 with
                H = NUM_HAND_JOINTS. (default=None)
            right_hand_pose: torch.tensor, optional
                Right hand rotations; reshaped to BxHx3x3. (default=None)
            jaw_pose: torch.tensor, optional
                Jaw rotation; reshaped to Bx1x3x3. (default=None)
            leye_pose: torch.tensor, optional
                Left eye rotation; reshaped to Bx1x3x3. (default=None)
            reye_pose: torch.tensor, optional
                Right eye rotation; reshaped to Bx1x3x3. (default=None)
            kwargs:
                Accepted for call compatibility and ignored.
            Returns
            -------
                global_rotmat: the 3x3 block of every accumulated transform
                posed_joints: the translation column of every accumulated
                    transform
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every pose input. The eye poses are
        # included so that a call passing only an eye pose with B > 1 gets
        # identity defaults of the matching batch size.
        pose_inputs = (
            global_orient, body_pose, left_hand_pose, right_hand_pose, jaw_pose, leye_pose,
            reye_pose
        )
        batch_size = max([1] + [len(p) for p in pose_inputs if p is not None])

        identity = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)

        def identity_pose(num_joints):
            # Fresh contiguous identity rotations of shape (B, num_joints, 3, 3).
            return identity.expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_pose(1)
        if body_pose is None:
            body_pose = identity_pose(self.NUM_BODY_JOINTS)
        # Use the same joint count as the reshape below instead of a literal 15.
        if left_hand_pose is None:
            left_hand_pose = identity_pose(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity_pose(self.NUM_HAND_JOINTS)
        if jaw_pose is None:
            jaw_pose = identity_pose(1)
        if leye_pose is None:
            leye_pose = identity_pose(1)
        if reye_pose is None:
            reye_pose = identity_pose(1)

        # Joint order: root, body, jaw, left eye, right eye, left hand, right hand.
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
            ],
            dim=1
        )
        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # Joints of the unposed template, regressed on every call.
        joints = vertices2joints(
            self.J_regressor,
            self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        )
        joints = torch.unsqueeze(joints, dim=-1)

        # Express every non-root joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        transforms_mat = transform_mat(
            rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
        ).reshape(-1, joints.shape[1], 4, 4)

        # Parents precede children, so the parent's accumulated transform is
        # already available when a child is processed.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, self.parents.shape[0]):
            curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        global_rotmat = transforms[:, :, :3, :3]
        posed_joints = transforms[:, :, :3, 3]

        return global_rotmat, posed_joints
| |
|
| |
|
class SMPLX_ALL(nn.Module):
    """Gender-aware SMPL-X wrapper that also returns SMPL-topology outputs.

    Holds one ``SMPLX`` model per enabled gender, routes each batch element to
    the model selected by its gender code, and converts the SMPL-X vertices to
    SMPL vertices through the ``smplx_to_smpl`` transfer matrix.
    """

    # Position in this tuple is the integer code expected in ``gender``
    # tensors: 0 = male, 1 = female, 2 = neutral.
    _GENDER_NAMES = ('male', 'female', 'neutral')

    def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs):
        """Load the SMPL-X model(s) and the index tables used by ``forward``.

        Parameters
        ----------
        batch_size: int
            Forwarded to every ``SMPLX`` model.
        use_face_contour: bool
            Forwarded to every ``SMPLX`` model; also selects how many trailing
            joints ``forward`` returns as face joints (68 vs. 51).
        all_gender: bool
            Load male, female and neutral models instead of neutral only.
        kwargs:
            Extra keyword arguments forwarded to every ``SMPLX`` model.
        """
        super().__init__()
        numBetas = 10
        self.use_face_contour = use_face_contour
        self.genders = ['male', 'female', 'neutral'] if all_gender else ['neutral']
        self.model_dict = nn.ModuleDict(
            {
                gender: SMPLX(
                    path_config.SMPL_MODEL_DIR,
                    gender=gender,
                    ext='npz',
                    num_betas=numBetas,
                    use_pca=False,
                    batch_size=batch_size,
                    use_face_contour=use_face_contour,
                    num_pca_comps=45,
                    **kwargs
                )
                for gender in self.genders
            }
        )
        # Registered as a second submodule reference; kept so existing state
        # dicts keep the same keys.
        self.model_neutral = self.model_dict['neutral']

        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)

        # Close the transfer file deterministically instead of leaking the handle.
        # NOTE(review): pickle executes arbitrary code on load; the file is
        # expected to be a trusted project asset.
        transfer_path = os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl')
        with open(transfer_path, 'rb') as transfer_file:
            smplx_to_smpl = pickle.load(transfer_file)
        self.register_buffer(
            'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)
        )

        smpl2limb_vert_faces = get_partial_smpl('smpl')
        self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
        self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()

        def side_joint_ids(side, names):
            # SMPL-X joint indices for "<side>_<name>" entries.
            return torch.tensor(
                [constants.SMPLX_JOINT_IDS['{}_{}'.format(side, name)] for name in names],
                dtype=torch.long
            )

        # Hand joint lookups.
        self.smplx2lh_joint_map = side_joint_ids('left', constants.HAND_NAMES)
        self.smplx2rh_joint_map = side_joint_ids('right', constants.HAND_NAMES)

        # Foot joint lookups.
        self.smplx2lf_joint_map = side_joint_ids('left', constants.FOOT_NAMES)
        self.smplx2rf_joint_map = side_joint_ids('right', constants.FOOT_NAMES)

        # Per-gender rest-joint template and shape directions for the first
        # 24 regressor rows; consumed by get_tpose().
        for g in self.genders:
            J_template = torch.einsum(
                'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]
            )
            J_dirs = torch.einsum(
                'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]
            )

            self.register_buffer(f'{g}_J_template', J_template)
            self.register_buffer(f'{g}_J_dirs', J_dirs)

    def _group_by_gender(self, gender):
        """Return ``[(model_name, batch_indices), ...]`` for the codes present in ``gender``."""
        groups = []
        for code, name in enumerate(self._GENDER_NAMES):
            batch_idx = (gender == code).nonzero(as_tuple=True)[0]
            if len(batch_idx) == 0:
                continue
            groups.append((name, batch_idx))
        return groups

    @staticmethod
    def _restore_order(groups, device):
        """Return indices that put gender-grouped results back in batch order.

        Concatenating the per-gender index tensors gives a permutation of the
        batch; its argsort is the inverse permutation. This replaces a
        ``list.index`` lookup per element (quadratic) and avoids converting
        every index to a Python int.
        """
        grouped_order = torch.cat([batch_idx for _, batch_idx in groups])
        return torch.argsort(grouped_order).long().to(device)

    def forward(self, *args, **kwargs):
        """Pose the SMPL-X model(s) and return SMPL / SMPL-X vertices and joints.

        Keyword arguments
        -----------------
        body_pose: required; its leading dimension defines the batch size.
        pose2rot: bool, default True. When true, every given pose entry is
            treated as axis-angle and converted to rotation matrices of shape
            (B, -1, 3, 3) before the model call.
        gender: optional tensor of codes (0 male, 1 female, 2 neutral);
            defaults to neutral for the whole batch.
        betas and the remaining pose keys are forwarded when present.

        Returns
        -------
        ModelOutput with SMPL-topology ``vertices``, ``smplx_vertices``, hand
        vertices, the remapped ``joints`` / ``joints_J19`` / ``smpl_joints``
        and the hand, foot and face joint subsets.
        """
        batch_size = kwargs['body_pose'].shape[0]
        device = kwargs['body_pose'].device
        kwargs['get_skin'] = True
        kwargs.setdefault('pose2rot', True)
        if 'gender' not in kwargs:
            kwargs['gender'] = 2 * torch.ones(batch_size).to(device)

        pose_keys = [
            'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose',
            'leye_pose', 'reye_pose'
        ]
        param_keys = ['betas'] + pose_keys
        if kwargs['pose2rot']:
            for key in pose_keys:
                if key in kwargs:
                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
                        [batch_size, -1, 3, 3]
                    )
        # NOTE(review): the nesting of this check was lost in the extracted
        # source; it is applied regardless of pose2rot here - confirm.
        if kwargs['body_pose'].shape[1] == 23:
            # A 23-joint body pose is cut down to its first 21 joints.
            kwargs['body_pose'] = kwargs['body_pose'][:, :21]

        groups = self._group_by_gender(kwargs['gender'])
        smplx_vertices, smplx_joints = [], []
        for g, gender_idx in groups:
            gender_kwargs = {'get_skin': kwargs['get_skin'], 'pose2rot': kwargs['pose2rot']}
            gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs})
            gender_smplx_output = self.model_dict[g].forward(*args, **gender_kwargs)
            smplx_vertices.append(gender_smplx_output.vertices)
            smplx_joints.append(gender_smplx_output.joints)

        idx_rearrange = self._restore_order(groups, device)

        smplx_vertices = torch.cat(smplx_vertices)[idx_rearrange]
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]

        # Hand joints.
        lhand_joints = smplx_joints[:, self.smplx2lh_joint_map]
        rhand_joints = smplx_joints[:, self.smplx2rh_joint_map]
        # Face joints are taken from the end of the joint list.
        face_joints = smplx_joints[:, -68:] if self.use_face_contour else smplx_joints[:, -51:]
        # Foot joints.
        lfoot_joints = smplx_joints[:, self.smplx2lf_joint_map]
        rfoot_joints = smplx_joints[:, self.smplx2rf_joint_map]

        smpl_vertices = torch.bmm(self.smplx2smpl.expand(batch_size, -1, -1), smplx_vertices)
        lhand_vertices = smpl_vertices[:, self.smpl2lhand]
        rhand_vertices = smpl_vertices[:, self.smpl2rhand]
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_vertices)

        smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45]
        joints = torch.cat([smplx_j45, extra_joints], dim=1)
        smpl_joints = smplx_j45[:, :24]
        joints = joints[:, self.joint_map, :]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]

        return ModelOutput(
            vertices=smpl_vertices,
            smplx_vertices=smplx_vertices,
            lhand_vertices=lhand_vertices,
            rhand_vertices=rhand_vertices,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            lhand_joints=lhand_joints,
            rhand_joints=rhand_joints,
            lfoot_joints=lfoot_joints,
            rfoot_joints=rfoot_joints,
            face_joints=face_joints,
        )

    def get_tpose(self, betas=None, gender=None):
        """Return the rest-pose joints (first 24 regressor rows) for given shapes.

        Parameters
        ----------
        betas: torch.tensor, optional
            Shape coefficients, one row per batch element. Defaults to a
            single all-zero row of length 10.
        gender: torch.tensor, optional
            Gender codes (0 male, 1 female, 2 neutral), one per batch element.
            Defaults to neutral.

        Returns
        -------
        Tensor of joints in the original batch order.
        """
        if betas is None:
            betas = torch.zeros(1, 10).to(self.J_regressor_extra.device)

        batch_size = betas.shape[0]
        device = betas.device

        if gender is None:
            gender = 2 * torch.ones(batch_size).to(device)

        groups = self._group_by_gender(gender)
        smplx_joints = []
        for g, gender_idx in groups:
            # Template joints plus the shape-dependent offset.
            J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(
                betas[gender_idx], getattr(self, f'{g}_J_dirs')
            )
            smplx_joints.append(J)

        idx_rearrange = self._restore_order(groups, device)

        return torch.cat(smplx_joints)[idx_rearrange]
| |
|
| |
|
class MANO(MANOLayer):
    """MANO hand layer that appends fingertip vertices to the joint list."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Pose the hand and return vertices plus a 21-entry joint tensor.

        Requires the ``global_orient`` and ``right_hand_pose`` keyword
        arguments. With ``pose2rot`` true (the default) both are treated as
        axis-angle and converted to rotation matrices of shape (B, -1, 3, 3).
        ``right_hand_pose`` is passed to the base layer as ``hand_pose``.
        """
        kwargs.setdefault('pose2rot', True)
        batch_size = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for key in ('global_orient', 'right_hand_pose'):
                if key not in kwargs:
                    continue
                axis_angle = kwargs[key].contiguous().view(-1, 3)
                kwargs[key] = batch_rodrigues(axis_angle).view([batch_size, -1, 3, 3])
        kwargs['hand_pose'] = kwargs.pop('right_hand_pose')

        mano_output = super().forward(*args, **kwargs)
        hand_verts = mano_output.vertices

        # Five mesh vertices are appended after the model joints
        # (presumably the fingertips - confirm the vertex ids).
        tip_verts = hand_verts[:, [745, 317, 445, 556, 673]]
        hand_joints = torch.cat([mano_output.joints, tip_verts], 1)

        # Reorder to: root, then four groups of three joints each followed by
        # one appended vertex (indices 16-20 assume the base layer returns
        # 16 joints - TODO confirm).
        joint_order = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
        hand_joints = hand_joints[:, joint_order]

        return ModelOutput(
            rhand_vertices=hand_verts,
            rhand_joints=hand_joints,
        )
| |
|
| |
|
class FLAME(FLAMELayer):
    """FLAME head layer that returns its result as a ``ModelOutput``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Pose the head model and return its vertices and face joints.

        Requires the ``global_orient`` keyword argument. With ``pose2rot``
        true (the default) every given pose entry is treated as axis-angle and
        converted to rotation matrices of shape (B, -1, 3, 3).
        """
        kwargs.setdefault('pose2rot', True)
        batch_size = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for key in ('global_orient', 'jaw_pose', 'leye_pose', 'reye_pose'):
                if key not in kwargs:
                    continue
                axis_angle = kwargs[key].contiguous().view(-1, 3)
                kwargs[key] = batch_rodrigues(axis_angle).view([batch_size, -1, 3, 3])

        head = super().forward(*args, **kwargs)
        return ModelOutput(
            flame_vertices=head.vertices,
            # The first five joints are dropped from the face joint set.
            face_joints=head.joints[:, 5:],
        )
| |
|
| |
|
class SMPL_Family():
    """Thin dispatcher that builds one of the model wrappers by name."""

    def __init__(self, model_type='smpl', *args, **kwargs):
        """Instantiate the wrapper selected by ``model_type``.

        Parameters
        ----------
        model_type: str
            One of ``'smpl'``, ``'smplx'``, ``'mano'`` or ``'flame'``.
        args, kwargs:
            Forwarded to the selected wrapper's constructor.

        Raises
        ------
        ValueError
            If ``model_type`` is not one of the supported names. Previously
            the instance was left without a ``model`` attribute and failed
            later with an unrelated ``AttributeError``.
        """
        if model_type == 'smpl':
            self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs)
        elif model_type == 'smplx':
            self.model = SMPLX_ALL(*args, **kwargs)
        elif model_type == 'mano':
            self.model = MANO(
                model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs
            )
        elif model_type == 'flame':
            self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs)
        else:
            raise ValueError(
                "Unsupported model_type {!r}; expected 'smpl', 'smplx', 'mano' or 'flame'".
                format(model_type)
            )

    def __call__(self, *args, **kwargs):
        """Call the wrapped model with the given arguments."""
        return self.model(*args, **kwargs)

    def get_tpose(self, *args, **kwargs):
        """Delegate to the wrapped model's ``get_tpose``.

        In this file only the ``'smplx'`` wrapper defines ``get_tpose``;
        whether the other wrappers inherit one is not visible here.
        """
        return self.model.get_tpose(*args, **kwargs)
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def get_smpl_faces():
    """Return the ``faces`` attribute of a freshly constructed SMPL model."""
    return SMPL(model_path=SMPL_MODEL_DIR, batch_size=1).faces
| |
|
| |
|
def get_smplx_faces():
    """Return the ``faces`` attribute of a freshly constructed SMPL-X model."""
    return SMPLX(SMPL_MODEL_DIR, batch_size=1).faces
| |
|
| |
|
def get_mano_faces(hand_type='right'):
    """Return the ``faces`` attribute of a MANO model for the given hand.

    ``hand_type`` must be ``'right'`` or ``'left'``; anything else fails the
    assertion below.
    """
    assert hand_type in ['right', 'left']
    hand_model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=(hand_type == 'right'))
    return hand_model.faces
| |
|
| |
|
def get_flame_faces():
    """Return the ``faces`` attribute of a freshly constructed FLAME model."""
    return FLAME(SMPL_MODEL_DIR, batch_size=1).faces
| |
|
| |
|
def get_model_faces(type='smpl'):
    """Return the face array of the model named by ``type``.

    Recognised names are ``'smpl'``, ``'smplx'``, ``'mano'`` and ``'flame'``;
    any other value yields ``None``.
    """
    if type == 'smpl':
        return get_smpl_faces()
    if type == 'smplx':
        return get_smplx_faces()
    if type == 'mano':
        return get_mano_faces()
    if type == 'flame':
        return get_flame_faces()
    return None
| |
|
| |
|
def get_model_tpose(type='smpl'):
    """Return the rest-pose vertices of the model named by ``type``.

    Recognised names are ``'smpl'``, ``'smplx'``, ``'mano'`` and ``'flame'``;
    any other value yields ``None``.
    """
    if type == 'smpl':
        return get_smpl_tpose()
    if type == 'smplx':
        return get_smplx_tpose()
    if type == 'mano':
        return get_mano_tpose()
    if type == 'flame':
        return get_flame_tpose()
    return None
| |
|
| |
|
def get_smpl_tpose():
    """Return the detached vertices of an SMPL model called with no arguments.

    The model is built with its own betas, global orientation and body pose
    parameters enabled so that it can be called without inputs.
    """
    model_options = dict(
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
    )
    output = SMPL(**model_options)()
    return output.vertices[0].detach()
| |
|
| |
|
def get_smpl_tpose_joint():
    """Return the detached ``smpl_joints`` of an SMPL model called with no arguments.

    The model is built with its own betas, global orientation and body pose
    parameters enabled so that it can be called without inputs.
    """
    model_options = dict(
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
    )
    output = SMPL(**model_options)()
    return output.smpl_joints[0].detach()
| |
|
| |
|
def get_smplx_tpose():
    """Return the vertices of an ``SMPLXLayer`` called with no arguments."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    return layer().vertices[0]
| |
|
| |
|
def get_smplx_tpose_joint():
    """Return the joints of an ``SMPLXLayer`` called with no arguments."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    return layer().joints[0]
| |
|
| |
|
def get_mano_tpose():
    """Return right-hand MANO vertices for all-zero orientation and hand pose."""
    hand_model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True)
    zero_orient = torch.zeros(1, 3)
    zero_hand_pose = torch.zeros(1, 15 * 3)
    output = hand_model(global_orient=zero_orient, right_hand_pose=zero_hand_pose)
    return output.rhand_vertices[0]
| |
|
| |
|
def get_flame_tpose():
    """Return FLAME vertices for an all-zero global orientation."""
    head_model = FLAME(SMPL_MODEL_DIR, batch_size=1)
    output = head_model(global_orient=torch.zeros(1, 3))
    return output.flame_vertices[0]
| |
|
| |
|
def get_part_joints(smpl_joints):
    """Build part joints: midpoints of joint pairs followed by single joints.

    Parameters
    ----------
    smpl_joints: torch.Tensor
        Joint tensor indexed as (batch, joint, ...). Joints up to index 21
        must exist for the pair lookups.

    Returns
    -------
    torch.Tensor
        Concatenation along dim 1 of the 18 pair midpoints (in the order of
        ``segment_pairs``) and the joints listed in ``single_joints``.
    """
    segment_pairs = [
        (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16),
        (14, 17), (1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21)
    ]
    single_joints = [10, 11, 15, 22, 23]

    # Gather both ends of every pair at once: (batch, 18, 2, ...), then
    # average over the pair axis instead of looping pair by pair.
    pair_index = torch.tensor(segment_pairs, dtype=torch.long, device=smpl_joints.device)
    midpoints = torch.mean(smpl_joints[:, pair_index], dim=2)

    # Slices (not integer indexing) keep the joint axis and, as before,
    # silently yield nothing for a joint index beyond the input's joint count.
    parts = [midpoints]
    parts.extend(smpl_joints[:, j:j + 1] for j in single_joints)

    return torch.cat(parts, dim=1)
| |
|
| |
|
# Segmentation labels that make up each region-based part.
_PART_SEGMENT_LABELS = {
    'face': ['head'],
    'arm': [
        'rightHand',
        'leftArm',
        'leftShoulder',
        'rightShoulder',
        'rightArm',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'forearm': [
        'rightHand',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'larm': ['leftForeArm'],
    'rarm': ['rightForeArm'],
}


def _extract_submesh(faces, num_verts, seed_vids):
    """Return (vertex ids, re-indexed faces) of all faces touching ``seed_vids``.

    A face is kept when at least one of its vertices is in ``seed_vids``. The
    returned vertex ids are the sorted unique vertices of the kept faces, and
    the returned faces index into that list.
    """
    # One vectorised membership test instead of a Python scan per face.
    face_ids = np.nonzero(np.isin(faces, seed_vids).any(axis=1))[0]
    kept_faces = faces[face_ids]

    vids = np.unique(kept_faces).astype(np.longlong)

    # Lookup table from full-mesh vertex id to position within ``vids``.
    remap = np.arange(num_verts)
    remap[vids] = np.arange(len(vids))

    return vids, remap[kept_faces].astype(np.longlong)


def _build_hand_part(part):
    """Map MANO hand vertices to their closest SMPL vertices for ``part``."""
    # NOTE(review): pickle executes arbitrary code on load; both files are
    # expected to be trusted project assets.
    with open(
        os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb'
    ) as pkl_file:
        smplx_mano_id = pickle.load(pkl_file)
    with open(
        os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb'
    ) as pkl_file:
        smplx_smpl_id = pickle.load(pkl_file)

    smplx_tpose = get_smplx_tpose()
    # NOTE(review): this mixes a numpy matrix with a torch tensor and the
    # loop below applies torch ops to the result; kept statement-for-statement
    # because the resulting types cannot be verified from this file.
    smpl_tpose = np.matmul(smplx_smpl_id['matrix'], smplx_tpose)

    if part == 'lhand':
        mano_vert = smplx_tpose[smplx_mano_id['left_hand']]
    elif part == 'rhand':
        mano_vert = smplx_tpose[smplx_mano_id['right_hand']]

    # For every hand vertex pick the SMPL vertex with the smallest squared distance.
    smpl2mano_id = []
    for vert in mano_vert:
        v_diff = smpl_tpose - vert
        v_diff = torch.sum(v_diff * v_diff, dim=1)
        v_closest = torch.argmin(v_diff)
        smpl2mano_id.append(int(v_closest))

    smpl2mano_vids = np.array(smpl2mano_id).astype(np.longlong)
    mano_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left'
                               ).astype(np.longlong)
    return smpl2mano_vids, mano_faces


def _build_segment_part(part, body_model, body_model_faces, body_model_num_verts):
    """Build the sub-mesh covering the segmentation labels assigned to ``part``."""
    with open(
        os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), 'rb'
    ) as json_file:
        vert_segmentation = json.load(json_file)

    seed_vids = []
    for label in _PART_SEGMENT_LABELS[part]:
        seed_vids.extend(vert_segmentation[label])

    return _extract_submesh(body_model_faces, body_model_num_verts, seed_vids)


def _build_wrist_part(part, body_model, body_model_faces, body_model_num_verts):
    """Build the sub-mesh of faces touching vertices near the wrist joint."""
    if body_model == 'smplx':
        body_model_verts = get_smplx_tpose()
        tpose_joint = get_smplx_tpose_joint()
    elif body_model == 'smpl':
        body_model_verts = get_smpl_tpose()
        tpose_joint = get_smpl_tpose_joint()

    wrist_joint = tpose_joint[20] if part == 'lwrist' else tpose_joint[21]

    # The threshold is compared against the squared distance.
    dist = 0.005
    sq_dist = torch.sum((body_model_verts - wrist_joint)**2, dim=1)
    wrist_vids = torch.nonzero(sq_dist < dist, as_tuple=True)[0].cpu().numpy()

    return _extract_submesh(body_model_faces, body_model_num_verts, wrist_vids)


def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
    """Return vertex ids and faces of the body-part sub-meshes of a model.

    Each part is read from ``<PARTIAL_MESH_DIR>/<body_model>_<part>_vids.npz``
    when that file exists; otherwise it is computed and written to that path.

    Parameters
    ----------
    body_model: str
        Model name used for the face / rest-pose lookup and in file names.
    device:
        Unused; kept so existing callers keep working.

    Returns
    -------
    dict
        ``{part: {'vids': ndarray, 'faces': ndarray}}`` for the parts
        lhand, rhand, face, arm, forearm, larm, rarm, lwrist and rwrist.
    """
    body_model_faces = get_model_faces(body_model)
    body_model_num_verts = len(get_model_tpose(body_model))

    part_vert_faces = {}

    for part in ['lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm', 'lwrist', 'rwrist']:
        part_vid_fname = '{}/{}_{}_vids.npz'.format(path_config.PARTIAL_MESH_DIR, body_model, part)

        if os.path.exists(part_vid_fname):
            # Read both arrays inside the context so the archive is closed.
            with np.load(part_vid_fname) as part_vids:
                part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']}
            continue

        if part in ['lhand', 'rhand']:
            vids, faces = _build_hand_part(part)
        elif part in _PART_SEGMENT_LABELS:
            vids, faces = _build_segment_part(
                part, body_model, body_model_faces, body_model_num_verts
            )
        else:
            vids, faces = _build_wrist_part(
                part, body_model, body_model_faces, body_model_num_verts
            )

        np.savez(part_vid_fname, vids=vids, faces=faces)
        part_vert_faces[part] = {'vids': vids, 'faces': faces}

    return part_vert_faces
| |
|