| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional, Dict, Union |
| import os |
| import os.path as osp |
| import pickle |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from collections import namedtuple |
|
|
| import logging |
|
|
# Raise the threshold of the "smplx" logger so that only records of level
# ERROR and above are emitted through it.
logging.getLogger("smplx").setLevel(logging.ERROR)
|
|
| from .lbs import (lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords) |
|
|
| from .vertex_ids import vertex_ids as VERTEX_IDS |
| from .utils import (Struct, to_np, to_tensor, Tensor, Array, SMPLOutput, |
| SMPLHOutput, SMPLXOutput, MANOOutput, FLAMEOutput, |
| find_joint_kin_chain) |
| from .vertex_joint_selector import VertexJointSelector |
|
|
# Generic output record shared by the body models; every field is optional
# and defaults to None so a model only fills in what it produces.
ModelOutput = namedtuple(
    'ModelOutput',
    [
        'vertices',
        'joints',
        'full_pose',
        'betas',
        'global_orient',
        'body_pose',
        'expression',
        'left_hand_pose',
        'right_hand_pose',
        'jaw_pose',
    ],
)
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)
|
|
|
|
class SMPL(nn.Module):
    ''' SMPL body model.

    Loads the model data (template mesh, shape and pose blend shapes, joint
    regressor, kinematic tree and skinning weights) into buffers and,
    depending on the ``create_*`` flags, registers learnable default
    parameters that ``forward`` falls back to when an input is omitted.
    '''

    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    SHAPE_SPACE_DIM = 300

    def __init__(self,
                 model_path: str,
                 kid_template_path: str = '',
                 data_struct: Optional[Struct] = None,
                 create_betas: bool = True,
                 betas: Optional[Tensor] = None,
                 num_betas: int = 10,
                 create_global_orient: bool = True,
                 global_orient: Optional[Tensor] = None,
                 create_body_pose: bool = True,
                 body_pose: Optional[Tensor] = None,
                 create_transl: bool = True,
                 transl: Optional[Tensor] = None,
                 dtype=torch.float32,
                 batch_size: int = 1,
                 joint_mapper=None,
                 gender: str = 'neutral',
                 age: str = 'adult',
                 vertex_ids: Optional[Dict[str, int]] = None,
                 v_template: Optional[Union[Tensor, Array]] = None,
                 v_personal: Optional[Union[Tensor, Array]] = None,
                 **kwargs) -> None:
        ''' SMPL model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored. For a folder, the file
            ``SMPL_<GENDER>.pkl`` inside it is loaded.
        kid_template_path: str, optional
            Path to a vertex template loaded with ``np.load`` when
            ``age == 'kid'``. The loaded template is mean-centred and its
            difference to the model template is appended as one extra shape
            component, so the number of betas grows by one. (default = '')
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_global_orient: bool, optional
            Flag for creating a member variable for the global orientation
            of the body. (default = True)
        global_orient: torch.tensor, optional, Bx3
            The default value for the global orientation variable.
            (default = None)
        create_body_pose: bool, optional
            Flag for creating a member variable for the pose of the body.
            (default = True)
        body_pose: torch.tensor, optional, Bx(Body Joints * 3)
            The default value for the body pose variable.
            (default = None)
        num_betas: int, optional
            Number of shape components to use. Clamped to 10 when the model
            has fewer than SHAPE_SPACE_DIM components, and to
            SHAPE_SPACE_DIM otherwise. (default = 10)
        create_betas: bool, optional
            Flag for creating a member variable for the shape space
            (default = True).
        betas: torch.tensor, optional, Bx10
            The default value for the shape member variable.
            (default = None)
        create_transl: bool, optional
            Flag for creating a member variable for the translation
            of the body. (default = True)
        transl: torch.tensor, optional, Bx3
            The default value for the transl variable.
            (default = None)
        dtype: torch.dtype, optional
            The data type for the created variables
        batch_size: int, optional
            The batch size used for creating the member variables
        joint_mapper: object, optional
            An object that re-maps the joints. Useful if one wants to
            re-order the SMPL joints to some other convention (e.g. MSCOCO)
            (default = None)
        gender: str, optional
            Which gender to load
        age: str, optional
            'adult' or 'kid'; only 'kid' changes the behaviour (see
            ``kid_template_path``). (default = 'adult')
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected. (default = VERTEX_IDS['smplh'])
        v_template: torch.tensor or array, optional
            Replaces the template mesh stored in the model data.
        v_personal: torch.tensor or array, optional
            Offsets added to the template mesh. The addition is done out of
            place, so a caller-supplied ``v_template`` tensor is left
            untouched.
        **kwargs:
            Forwarded to ``VertexJointSelector``.
        '''
        # Plain attributes may be set before nn.Module.__init__ runs.
        self.gender = gender
        self.age = age

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
                smpl_path = os.path.join(model_path, model_fn)
            else:
                smpl_path = model_path
            assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
                smpl_path)

            # NOTE(review): pickle executes arbitrary code on load; only use
            # trusted model files.
            with open(smpl_path, 'rb') as smpl_file:
                data_struct = Struct(
                    **pickle.load(smpl_file, encoding='latin1'))

        super(SMPL, self).__init__()
        self.batch_size = batch_size
        shapedirs = data_struct.shapedirs
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM:
            # Model files with a reduced shape space: use at most 10 betas.
            num_betas = min(num_betas, 10)
        else:
            num_betas = min(num_betas, self.SHAPE_SPACE_DIM)

        if self.age == 'kid':
            # Append the (centred) kid template offset as an extra shape
            # direction; its coefficient is the last beta.
            v_template_smil = np.load(kid_template_path)
            v_template_smil -= np.mean(v_template_smil, axis=0)
            v_template_diff = np.expand_dims(
                v_template_smil - data_struct.v_template, axis=2)
            shapedirs = np.concatenate(
                (shapedirs[:, :, :num_betas], v_template_diff), axis=2)
            num_betas = num_betas + 1

        self._num_betas = num_betas
        shapedirs = shapedirs[:, :, :num_betas]
        # The shape components
        self.register_buffer('shapedirs',
                             to_tensor(to_np(shapedirs), dtype=dtype))

        if vertex_ids is None:
            # SMPL and SMPL-H share the same topology.
            vertex_ids = VERTEX_IDS['smplh']

        self.dtype = dtype

        self.joint_mapper = joint_mapper

        self.vertex_joint_selector = VertexJointSelector(vertex_ids=vertex_ids,
                                                         **kwargs)

        self.faces = data_struct.f
        self.register_buffer(
            'faces_tensor',
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long))

        if create_betas:
            if betas is None:
                default_betas = torch.zeros([batch_size, self.num_betas],
                                            dtype=dtype)
            elif torch.is_tensor(betas):
                default_betas = betas.clone().detach()
            else:
                default_betas = torch.tensor(betas, dtype=dtype)

            self.register_parameter(
                'betas', nn.Parameter(default_betas, requires_grad=True))

        if create_global_orient:
            if global_orient is None:
                default_global_orient = torch.zeros([batch_size, 3],
                                                    dtype=dtype)
            elif torch.is_tensor(global_orient):
                default_global_orient = global_orient.clone().detach()
            else:
                default_global_orient = torch.tensor(global_orient,
                                                     dtype=dtype)

            global_orient = nn.Parameter(default_global_orient,
                                         requires_grad=True)
            self.register_parameter('global_orient', global_orient)

        if create_body_pose:
            if body_pose is None:
                default_body_pose = torch.zeros(
                    [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
            elif torch.is_tensor(body_pose):
                default_body_pose = body_pose.clone().detach()
            else:
                default_body_pose = torch.tensor(body_pose, dtype=dtype)
            self.register_parameter(
                'body_pose', nn.Parameter(default_body_pose,
                                          requires_grad=True))

        if create_transl:
            if transl is None:
                default_transl = torch.zeros([batch_size, 3],
                                             dtype=dtype,
                                             requires_grad=True)
            elif torch.is_tensor(transl):
                # Same result as torch.tensor(transl, dtype=dtype) without the
                # "copy construct from a tensor" warning.
                default_transl = transl.detach().clone().to(dtype=dtype)
            else:
                default_transl = torch.tensor(transl, dtype=dtype)
            self.register_parameter(
                'transl', nn.Parameter(default_transl, requires_grad=True))

        if v_template is None:
            v_template = data_struct.v_template

        if not torch.is_tensor(v_template):
            v_template = to_tensor(to_np(v_template), dtype=dtype)

        if v_personal is not None:
            v_personal = to_tensor(to_np(v_personal), dtype=dtype)
            # Out of place: do not mutate a template tensor owned by the
            # caller (and do not fail on a leaf tensor requiring grad).
            v_template = v_template + v_personal

        # The vertices of the template model
        self.register_buffer('v_template', v_template)

        j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype)
        self.register_buffer('J_regressor', j_regressor)

        # Pose blend shapes: flatten every basis except the last axis and
        # transpose so that each row is one pose basis.
        num_pose_basis = data_struct.posedirs.shape[-1]
        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs', to_tensor(to_np(posedirs),
                                                   dtype=dtype))

        # Kinematic tree: first row of kintree_table, root marked with -1.
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer('parents', parents)

        self.register_buffer(
            'lbs_weights', to_tensor(to_np(data_struct.weights), dtype=dtype))

    @property
    def num_betas(self):
        ''' Number of shape components in use (includes the kid component). '''
        return self._num_betas

    @property
    def num_expression_coeffs(self):
        ''' SMPL has no expression space. '''
        return 0

    def create_mean_pose(self, data_struct) -> Tensor:
        ''' No mean pose for plain SMPL; returns None. '''
        pass

    def name(self) -> str:
        return 'SMPL'

    @torch.no_grad()
    def reset_params(self, **params_dict) -> None:
        ''' Overwrite every named parameter in place.

        Parameters listed in ``params_dict`` receive the given value
        (broadcast into the existing parameter); all others are zeroed.
        '''
        for param_name, param in self.named_parameters():
            if param_name in params_dict:
                # as_tensor avoids the copy warning for tensor inputs; the
                # slice assignment copies the data into the parameter.
                param[:] = torch.as_tensor(params_dict[param_name])
            else:
                param.fill_(0)

    def get_num_verts(self) -> int:
        return self.v_template.shape[0]

    def get_num_faces(self) -> int:
        return self.faces.shape[0]

    def extra_repr(self) -> str:
        msg = [
            f'Gender: {self.gender.upper()}',
            f'Number of joints: {self.J_regressor.shape[0]}',
            f'Betas: {self.num_betas}',
        ]
        return '\n'.join(msg)

    def forward(self,
                betas: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts=True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLOutput:
        ''' Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model. A batch of one
            is broadcast to the pose batch size. (default=None)
        body_pose: torch.tensor, optional, shape Bx(J*3)
            If given, ignore the member variable `body_pose` and use it
            instead. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            axis-angle format. (default=None)
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full axis-angle pose vector (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``; the default True is used with axis-angle
            inputs. (default=True)

        Returns
        -------
        SMPLOutput
            ``vertices`` (None unless ``return_verts``), ``joints`` (after
            the vertex-joint selector and optional joint mapper), ``betas``,
            ``global_orient``, ``body_pose`` and ``full_pose`` (None unless
            ``return_full_pose``). Translation, when available, is applied
            to vertices and joints.
        '''
        # Fall back to the member parameters for every omitted input.
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas

        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl

        full_pose = torch.cat([global_orient, body_pose], dim=1)

        batch_size = max(betas.shape[0], global_orient.shape[0],
                         body_pose.shape[0])

        if betas.shape[0] != batch_size:
            # Share one shape vector across the batch. ``expand`` takes the
            # target size (not a repeat count) and only broadcasts singleton
            # dimensions, so this supports a betas batch of 1.
            betas = betas.expand(batch_size, -1)

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=pose2rot)

        # Add the extra vertex-based keypoints, then optionally remap.
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLOutput(vertices=vertices if return_verts else None,
                            global_orient=global_orient,
                            body_pose=body_pose,
                            joints=joints,
                            betas=betas,
                            full_pose=full_pose if return_full_pose else None)

        return output
|
|
|
|
class SMPLLayer(SMPL):
    ''' SMPL used as a stateless layer.

    No member parameters are created: every input is passed to ``forward``
    (omitted ones default to identity rotations or zeros), and poses are
    given as rotation matrices.
    '''

    def __init__(self, *args, **kwargs) -> None:
        # Disable every member parameter; all inputs come through forward.
        super(SMPLLayer, self).__init__(
            create_body_pose=False,
            create_betas=False,
            create_global_orient=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(self,
                betas: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts=True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLOutput:
        ''' Forward pass for the SMPL model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. (default=None, identity)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None, zeros)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body pose. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None, identity)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None, zeros)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose as rotation matrices of shape
            Bx(1 + NUM_BODY_JOINTS)x3x3. (default=False)
        pose2rot: bool, optional
            Ignored: this layer always calls ``lbs`` with ``pose2rot=False``.

        Returns
        -------
        SMPLOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose`` and ``full_pose``
            (None unless ``return_full_pose``).
        '''
        # The batch size is the largest leading dimension among the inputs.
        model_vars = [betas, global_orient, body_pose, transl]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))
        device, dtype = self.shapedirs.device, self.shapedirs.dtype
        if global_orient is None:
            global_orient = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
        if body_pose is None:
            body_pose = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1,
                                   -1).contiguous()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)
        ],
                              dim=1)

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=False)

        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # ``transl`` is always a tensor here (zeros when not given).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = SMPLOutput(vertices=vertices if return_verts else None,
                            global_orient=global_orient,
                            body_pose=body_pose,
                            joints=joints,
                            betas=betas,
                            full_pose=full_pose if return_full_pose else None)

        return output
|
|
|
|
class SMPLH(SMPL):
    ''' SMPL+H: SMPL body with articulated hands.

    The body keeps ``SMPL.NUM_JOINTS - 2`` joints and each hand adds
    ``NUM_HAND_JOINTS`` joints. Hand poses may be given either as PCA
    coefficients or as per-joint axis-angle vectors.
    '''

    NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS

    def __init__(self,
                 model_path,
                 kid_template_path: str = '',
                 data_struct: Optional[Struct] = None,
                 create_left_hand_pose: bool = True,
                 left_hand_pose: Optional[Tensor] = None,
                 create_right_hand_pose: bool = True,
                 right_hand_pose: Optional[Tensor] = None,
                 use_pca: bool = True,
                 num_pca_comps: int = 6,
                 flat_hand_mean: bool = False,
                 batch_size: int = 1,
                 gender: str = 'neutral',
                 age: str = 'adult',
                 dtype=torch.float32,
                 vertex_ids=None,
                 use_compressed: bool = True,
                 ext: str = 'pkl',
                 **kwargs) -> None:
        ''' SMPLH model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored. For a folder, the file
            ``SMPLH_<GENDER>.<ext>`` inside it is loaded.
        kid_template_path: str, optional
            Forwarded to SMPL; used only when ``age == 'kid'``.
        data_struct: Strct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_left_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the left
            hand. (default = True)
        left_hand_pose: torch.tensor, optional, BxP
            The default value for the left hand pose member variable.
            (default = None)
        create_right_hand_pose: bool, optional
            Flag for creating a member variable for the pose of the right
            hand. (default = True)
        right_hand_pose: torch.tensor, optional, BxP
            The default value for the right hand pose member variable.
            (default = None)
        use_pca: bool, optional
            If True, hand poses are ``num_pca_comps`` PCA coefficients that
            are mapped to joint rotations with the first ``num_pca_comps``
            rows of the model's hand components; otherwise hand poses have
            ``3 * NUM_HAND_JOINTS`` values. (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for each hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the hand mean poses are replaced by zeros (flat hands);
            otherwise the means stored in the model data are used. The means
            are part of ``pose_mean``, which ``forward`` adds to the full
            pose. (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load
        age: str, optional
            Forwarded to SMPL ('adult' or 'kid'). (default = 'adult')
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected. (default = VERTEX_IDS['smplh'])
        use_compressed: bool, optional
            Forwarded through SMPL to ``VertexJointSelector``.
        ext: str, optional
            Model file extension, 'pkl' or 'npz'. (default = 'pkl')
        '''
        # Plain attribute; safe to set before nn.Module.__init__ runs.
        self.num_pca_comps = num_pca_comps

        # If no model data is given, load it from model_path.
        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext)
                smplh_path = os.path.join(model_path, model_fn)
            else:
                smplh_path = model_path
            assert osp.exists(smplh_path), 'Path {} does not exist!'.format(
                smplh_path)

            # NOTE(review): pickle / allow_pickle can execute arbitrary code;
            # only load trusted model files.
            if ext == 'pkl':
                with open(smplh_path, 'rb') as smplh_file:
                    model_data = pickle.load(smplh_file, encoding='latin1')
            elif ext == 'npz':
                model_data = np.load(smplh_path, allow_pickle=True)
            else:
                raise ValueError('Unknown extension: {}'.format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS['smplh']

        super(SMPLH, self).__init__(model_path=model_path,
                                    kid_template_path=kid_template_path,
                                    data_struct=data_struct,
                                    batch_size=batch_size,
                                    vertex_ids=vertex_ids,
                                    gender=gender,
                                    age=age,
                                    use_compressed=use_compressed,
                                    dtype=dtype,
                                    ext=ext,
                                    **kwargs)

        self.use_pca = use_pca
        self.flat_hand_mean = flat_hand_mean

        left_hand_components = data_struct.hands_componentsl[:num_pca_comps]
        right_hand_components = data_struct.hands_componentsr[:num_pca_comps]

        self.np_left_hand_components = left_hand_components
        self.np_right_hand_components = right_hand_components
        if self.use_pca:
            self.register_buffer(
                'left_hand_components',
                torch.tensor(left_hand_components, dtype=dtype))
            self.register_buffer(
                'right_hand_components',
                torch.tensor(right_hand_components, dtype=dtype))

        if self.flat_hand_mean:
            left_hand_mean = np.zeros_like(data_struct.hands_meanl)
            right_hand_mean = np.zeros_like(data_struct.hands_meanr)
        else:
            left_hand_mean = data_struct.hands_meanl
            right_hand_mean = data_struct.hands_meanr

        self.register_buffer('left_hand_mean',
                             to_tensor(left_hand_mean, dtype=self.dtype))
        self.register_buffer('right_hand_mean',
                             to_tensor(right_hand_mean, dtype=self.dtype))

        # Size of one hand's pose vector: PCA coefficients or axis-angles.
        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS
        if create_left_hand_pose:
            if left_hand_pose is None:
                default_lhand_pose = torch.zeros([batch_size, hand_pose_dim],
                                                 dtype=dtype)
            elif torch.is_tensor(left_hand_pose):
                # Same as torch.tensor(..., dtype=dtype) without the warning.
                default_lhand_pose = left_hand_pose.detach().clone().to(
                    dtype=dtype)
            else:
                default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype)

            left_hand_pose_param = nn.Parameter(default_lhand_pose,
                                                requires_grad=True)
            self.register_parameter('left_hand_pose', left_hand_pose_param)

        if create_right_hand_pose:
            if right_hand_pose is None:
                default_rhand_pose = torch.zeros([batch_size, hand_pose_dim],
                                                 dtype=dtype)
            elif torch.is_tensor(right_hand_pose):
                default_rhand_pose = right_hand_pose.detach().clone().to(
                    dtype=dtype)
            else:
                default_rhand_pose = torch.tensor(right_hand_pose,
                                                  dtype=dtype)

            right_hand_pose_param = nn.Parameter(default_rhand_pose,
                                                 requires_grad=True)
            self.register_parameter('right_hand_pose', right_hand_pose_param)

        # Offset added to every full pose in forward.
        pose_mean_tensor = self.create_mean_pose(data_struct,
                                                 flat_hand_mean=flat_hand_mean)
        if not torch.is_tensor(pose_mean_tensor):
            pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)
        self.register_buffer('pose_mean', pose_mean_tensor)

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        ''' Build the flat mean pose: zeros for the global orientation and
        body, followed by the left and right hand means. '''
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3],
                                     dtype=self.dtype)

        pose_mean = torch.cat([
            global_orient_mean, body_pose_mean, self.left_hand_mean,
            self.right_hand_mean
        ],
                              dim=0)
        return pose_mean

    def name(self) -> str:
        return 'SMPL+H'

    def extra_repr(self):
        msg = super(SMPLH, self).extra_repr()
        msg = [msg]
        if self.use_pca:
            msg.append(f'Number of PCA components: {self.num_pca_comps}')
        msg.append(f'Flat hand mean: {self.flat_hand_mean}')
        return '\n'.join(msg)

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                left_hand_pose: Optional[Tensor] = None,
                right_hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLHOutput:
        ''' Forward pass for the SMPL+H model

        Each omitted input falls back to the corresponding member variable.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters.
        global_orient: torch.tensor, optional, shape Bx3
            Global rotation of the body.
        body_pose: torch.tensor, optional, shape Bx(NUM_BODY_JOINTS*3)
            Body joint rotations.
        left_hand_pose, right_hand_pose: torch.tensor, optional, shape BxP
            Hand poses; P is ``num_pca_comps`` when ``use_pca`` (the
            coefficients are projected with the hand components), otherwise
            ``3 * NUM_HAND_JOINTS``.
        transl: torch.tensor, optional, shape Bx3
            Translation applied to vertices and joints.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose vector, including ``pose_mean``.
            (default=False)
        pose2rot: bool, optional
            Forwarded to ``lbs``. (default=True)

        Returns
        -------
        SMPLHOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose``, the hand poses
            (after the PCA projection, without the mean) and ``full_pose``
            (None unless ``return_full_pose``).
        '''
        # Fall back to the member parameters for every omitted input.
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        left_hand_pose = (left_hand_pose if left_hand_pose is not None else
                          self.left_hand_pose)
        right_hand_pose = (right_hand_pose if right_hand_pose is not None else
                           self.right_hand_pose)

        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl

        if self.use_pca:
            # Project PCA coefficients to per-joint rotations.
            left_hand_pose = torch.einsum(
                'bi,ij->bj', [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                'bi,ij->bj', [right_hand_pose, self.right_hand_components])

        full_pose = torch.cat(
            [global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1)

        # In place is safe: full_pose is a fresh tensor from torch.cat.
        full_pose += self.pose_mean

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=pose2rot)

        # Add the extra vertex-based keypoints, then optionally remap.
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=left_hand_pose,
                             right_hand_pose=right_hand_pose,
                             full_pose=full_pose if return_full_pose else None)

        return output
|
|
|
|
class SMPLHLayer(SMPLH):
    ''' SMPL+H used as a stateless layer.

    No member parameters are created; all inputs are passed to ``forward``
    as rotation matrices, and omitted ones default to identity rotations or
    zeros.
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' SMPL+H as a layer model constructor
        '''
        super(SMPLHLayer, self).__init__(create_global_orient=False,
                                         create_body_pose=False,
                                         create_left_hand_pose=False,
                                         create_right_hand_pose=False,
                                         create_betas=False,
                                         create_transl=False,
                                         *args,
                                         **kwargs)

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                left_hand_pose: Optional[Tensor] = None,
                right_hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLHOutput:
        ''' Forward pass for the SMPL+H model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Global rotation of the body. Useful if someone wishes to
            predicts this with an external model. It is expected to be in
            rotation matrix format. (default=None, identity)
        betas: torch.tensor, optional, shape BxN_b
            Shape parameters. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None, zeros)
        body_pose: torch.tensor, optional, shape BxJx3x3
            If given, ignore the member variable `body_pose` and use it
            instead. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None, identity)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the left hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None, identity)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the right hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None, identity)
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None, zeros)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full pose as rotation matrices of shape
            Bx(1 + NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS)x3x3.
            (default=False)
        pose2rot: bool, optional
            Ignored: this layer always calls ``lbs`` with ``pose2rot=False``.

        Returns
        -------
        SMPLHOutput
            ``vertices`` (None unless ``return_verts``), ``joints``,
            ``betas``, ``global_orient``, ``body_pose``, the hand poses and
            ``full_pose`` (None unless ``return_full_pose``).
        '''
        # The batch size is the largest leading dimension among the inputs.
        model_vars = [
            betas, global_orient, body_pose, transl, left_hand_pose,
            right_hand_pose
        ]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))
        device, dtype = self.shapedirs.device, self.shapedirs.dtype
        eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
        if global_orient is None:
            global_orient = eye.expand(batch_size, -1, -1, -1).contiguous()
        if body_pose is None:
            body_pose = eye.expand(batch_size, self.NUM_BODY_JOINTS, -1,
                                   -1).contiguous()
        if left_hand_pose is None:
            left_hand_pose = eye.expand(batch_size, self.NUM_HAND_JOINTS, -1,
                                        -1).contiguous()
        if right_hand_pose is None:
            right_hand_pose = eye.expand(batch_size, self.NUM_HAND_JOINTS, -1,
                                         -1).contiguous()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all rotation matrices into one full-pose tensor.
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
        ],
                              dim=1)

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=False)

        # Add the extra vertex-based keypoints, then optionally remap.
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # ``transl`` is always a tensor here (zeros when not given).
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=left_hand_pose,
                             right_hand_pose=right_hand_pose,
                             full_pose=full_pose if return_full_pose else None)

        return output
|
|
|
|
| class SMPLX(SMPLH): |
| ''' |
| SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters |
| trained jointly for the face, hands and body. |
| SMPL-X uses standard vertex based linear blend skinning with learned |
| corrective blend shapes, has N=10475 vertices and K=54 joints, |
| which includes joints for the neck, jaw, eyeballs and fingers. |
| ''' |
|
|
| NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS |
| NUM_HAND_JOINTS = 15 |
| NUM_FACE_JOINTS = 3 |
| NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS |
| EXPRESSION_SPACE_DIM = 100 |
| NECK_IDX = 12 |
|
|
| def __init__(self, |
| model_path: str, |
| kid_template_path: str = '', |
| num_expression_coeffs: int = 10, |
| create_expression: bool = True, |
| expression: Optional[Tensor] = None, |
| create_jaw_pose: bool = True, |
| jaw_pose: Optional[Tensor] = None, |
| create_leye_pose: bool = True, |
| leye_pose: Optional[Tensor] = None, |
| create_reye_pose=True, |
| reye_pose: Optional[Tensor] = None, |
| use_face_contour: bool = False, |
| batch_size: int = 1, |
| gender: str = 'neutral', |
| age: str = 'adult', |
| dtype=torch.float32, |
| ext: str = 'npz', |
| **kwargs) -> None: |
| ''' SMPLX model constructor |
| |
| Parameters |
| ---------- |
| model_path: str |
| The path to the folder or to the file where the model |
| parameters are stored |
| num_expression_coeffs: int, optional |
| Number of expression components to use |
| (default = 10). |
| create_expression: bool, optional |
| Flag for creating a member variable for the expression space |
| (default = True). |
| expression: torch.tensor, optional, Bx10 |
| The default value for the expression member variable. |
| (default = None) |
| create_jaw_pose: bool, optional |
| Flag for creating a member variable for the jaw pose. |
| (default = False) |
| jaw_pose: torch.tensor, optional, Bx3 |
| The default value for the jaw pose variable. |
| (default = None) |
| create_leye_pose: bool, optional |
| Flag for creating a member variable for the left eye pose. |
| (default = False) |
| leye_pose: torch.tensor, optional, Bx10 |
| The default value for the left eye pose variable. |
| (default = None) |
| create_reye_pose: bool, optional |
| Flag for creating a member variable for the right eye pose. |
| (default = False) |
| reye_pose: torch.tensor, optional, Bx10 |
| The default value for the right eye pose variable. |
| (default = None) |
| use_face_contour: bool, optional |
| Whether to compute the keypoints that form the facial contour |
| batch_size: int, optional |
| The batch size used for creating the member variables |
| gender: str, optional |
| Which gender to load |
| dtype: torch.dtype |
| The data type for the created variables |
| ''' |
|
|
| |
| if osp.isdir(model_path): |
| model_fn = 'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext) |
| smplx_path = os.path.join(model_path, model_fn) |
| else: |
| smplx_path = model_path |
| assert osp.exists(smplx_path), 'Path {} does not exist!'.format( |
| smplx_path) |
|
|
| if ext == 'pkl': |
| with open(smplx_path, 'rb') as smplx_file: |
| model_data = pickle.load(smplx_file, encoding='latin1') |
| elif ext == 'npz': |
| model_data = np.load(smplx_path, allow_pickle=True) |
| else: |
| raise ValueError('Unknown extension: {}'.format(ext)) |
|
|
| data_struct = Struct(**model_data) |
|
|
| super(SMPLX, self).__init__(model_path=model_path, |
| kid_template_path=kid_template_path, |
| data_struct=data_struct, |
| dtype=dtype, |
| batch_size=batch_size, |
| vertex_ids=VERTEX_IDS['smplx'], |
| gender=gender, |
| age=age, |
| ext=ext, |
| **kwargs) |
|
|
| lmk_faces_idx = data_struct.lmk_faces_idx |
| self.register_buffer('lmk_faces_idx', |
| torch.tensor(lmk_faces_idx, dtype=torch.long)) |
| lmk_bary_coords = data_struct.lmk_bary_coords |
| self.register_buffer('lmk_bary_coords', |
| torch.tensor(lmk_bary_coords, dtype=dtype)) |
|
|
| self.use_face_contour = use_face_contour |
| if self.use_face_contour: |
| dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx |
| dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, |
| dtype=torch.long) |
| self.register_buffer('dynamic_lmk_faces_idx', |
| dynamic_lmk_faces_idx) |
|
|
| dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords |
| dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords, |
| dtype=dtype) |
| self.register_buffer('dynamic_lmk_bary_coords', |
| dynamic_lmk_bary_coords) |
|
|
| neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) |
| self.register_buffer( |
| 'neck_kin_chain', torch.tensor(neck_kin_chain, |
| dtype=torch.long)) |
|
|
| if create_jaw_pose: |
| if jaw_pose is None: |
| default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| else: |
| default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) |
| jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) |
| self.register_parameter('jaw_pose', jaw_pose_param) |
|
|
| if create_leye_pose: |
| if leye_pose is None: |
| default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| else: |
| default_leye_pose = torch.tensor(leye_pose, dtype=dtype) |
| leye_pose_param = nn.Parameter(default_leye_pose, |
| requires_grad=True) |
| self.register_parameter('leye_pose', leye_pose_param) |
|
|
| if create_reye_pose: |
| if reye_pose is None: |
| default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| else: |
| default_reye_pose = torch.tensor(reye_pose, dtype=dtype) |
| reye_pose_param = nn.Parameter(default_reye_pose, |
| requires_grad=True) |
| self.register_parameter('reye_pose', reye_pose_param) |
|
|
| shapedirs = data_struct.shapedirs |
| if len(shapedirs.shape) < 3: |
| shapedirs = shapedirs[:, :, None] |
| if (shapedirs.shape[-1] < |
| self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM): |
| |
| |
| expr_start_idx = 10 |
| expr_end_idx = 20 |
| num_expression_coeffs = min(num_expression_coeffs, 10) |
| else: |
| expr_start_idx = self.SHAPE_SPACE_DIM |
| expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs |
| num_expression_coeffs = min(num_expression_coeffs, |
| self.EXPRESSION_SPACE_DIM) |
|
|
| self._num_expression_coeffs = num_expression_coeffs |
|
|
| expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] |
| self.register_buffer('expr_dirs', |
| to_tensor(to_np(expr_dirs), dtype=dtype)) |
|
|
| if create_expression: |
| if expression is None: |
| default_expression = torch.zeros( |
| [batch_size, self.num_expression_coeffs], dtype=dtype) |
| else: |
| default_expression = torch.tensor(expression, dtype=dtype) |
| expression_param = nn.Parameter(default_expression, |
| requires_grad=True) |
| self.register_parameter('expression', expression_param) |
|
|
def name(self) -> str:
    '''Return the human-readable identifier of this body model.'''
    model_label = 'SMPL-X'
    return model_label
|
|
@property
def num_expression_coeffs(self):
    '''Number of expression coefficients used by the model.

    Read-only view of ``_num_expression_coeffs``, which the constructor
    stores after clamping the requested count to what the model data
    provides.
    '''
    coeff_count = self._num_expression_coeffs
    return coeff_count
|
|
def create_mean_pose(self, data_struct, flat_hand_mean=False):
    ''' Build the flat mean-pose vector that is added to the full pose.

    The global orientation, body, jaw and both eye components are zeros;
    the hand components are taken from the ``left_hand_mean`` and
    ``right_hand_mean`` attributes (set up elsewhere, presumably by the
    parent constructor).

    Parameters
    ----------
    data_struct: Struct
        Unused here; kept for interface compatibility.
    flat_hand_mean: bool, optional
        Unused here; kept for interface compatibility.

    Returns
    -------
    pose_mean: torch.Tensor
        1-D tensor laid out as
        [global_orient(3), body(3 * NUM_BODY_JOINTS), jaw(3), leye(3),
        reye(3), left_hand_mean, right_hand_mean], with dtype
        ``self.dtype``.
    '''
    # torch.as_tensor accepts tensors and numpy arrays alike and keeps the
    # device of an existing tensor, so the concatenation below is always a
    # pure-torch operation (np.concatenate on tensors is not guaranteed
    # to return a torch.Tensor, which callers .clone()).
    left_hand_mean = torch.as_tensor(self.left_hand_mean, dtype=self.dtype)
    device = left_hand_mean.device
    right_hand_mean = torch.as_tensor(self.right_hand_mean,
                                      dtype=self.dtype,
                                      device=device)

    global_orient_mean = torch.zeros([3], dtype=self.dtype, device=device)
    body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3],
                                 dtype=self.dtype,
                                 device=device)
    jaw_pose_mean = torch.zeros([3], dtype=self.dtype, device=device)
    leye_pose_mean = torch.zeros([3], dtype=self.dtype, device=device)
    reye_pose_mean = torch.zeros([3], dtype=self.dtype, device=device)

    pose_mean = torch.cat([
        global_orient_mean, body_pose_mean, jaw_pose_mean, leye_pose_mean,
        reye_pose_mean, left_hand_mean, right_hand_mean
    ],
                          dim=0)
    return pose_mean
|
|
def extra_repr(self):
    ''' Extend the parent module summary with the expression-space size. '''
    parent_summary = super(SMPLX, self).extra_repr()
    expression_line = (
        f'Number of Expression Coefficients: {self.num_expression_coeffs}')
    return parent_summary + '\n' + expression_line
|
|
def forward(self,
            betas: Optional[Tensor] = None,
            global_orient: Optional[Tensor] = None,
            body_pose: Optional[Tensor] = None,
            left_hand_pose: Optional[Tensor] = None,
            right_hand_pose: Optional[Tensor] = None,
            transl: Optional[Tensor] = None,
            expression: Optional[Tensor] = None,
            jaw_pose: Optional[Tensor] = None,
            leye_pose: Optional[Tensor] = None,
            reye_pose: Optional[Tensor] = None,
            return_verts: bool = True,
            return_full_pose: bool = False,
            pose2rot: bool = True,
            return_joint_transformation: bool = False,
            return_vertex_transformation: bool = False,
            **kwargs) -> SMPLXOutput:
    '''
    Forward pass for the SMPLX model

    Parameters
    ----------
    global_orient: torch.tensor, optional, shape Bx3
        If given, ignore the member variable and use it as the global
        rotation of the body. Useful if someone wishes to predicts this
        with an external model. (default=None)
    betas: torch.tensor, optional, shape BxN_b
        If given, ignore the member variable `betas` and use it
        instead. For example, it can used if shape parameters
        `betas` are predicted from some external model.
        (default=None)
    expression: torch.tensor, optional, shape BxN_e
        If given, ignore the member variable `expression` and use it
        instead. For example, it can used if expression parameters
        `expression` are predicted from some external model.
    body_pose: torch.tensor, optional, shape Bx(J*3)
        If given, ignore the member variable `body_pose` and use it
        instead. For example, it can used if someone predicts the
        pose of the body joints are predicted from some external model.
        It should be a tensor that contains joint rotations in
        axis-angle format. (default=None)
    left_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `left_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    right_hand_pose: torch.tensor, optional, shape BxP
        If given, ignore the member variable `right_hand_pose` and
        use this instead. It should either contain PCA coefficients or
        joint rotations in axis-angle format.
    jaw_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `jaw_pose` and
        use this instead. It should contain joint rotations in
        axis-angle format.
    leye_pose, reye_pose: torch.tensor, optional, shape Bx3
        If given, ignore the member variables `leye_pose` / `reye_pose`
        and use these instead (axis-angle format).
    transl: torch.tensor, optional, shape Bx3
        If given, ignore the member variable `transl` and use it
        instead. For example, it can used if the translation
        `transl` is predicted from some external model.
        (default=None)
    return_verts: bool, optional
        Return the vertices. (default=True)
    return_full_pose: bool, optional
        Returns the full axis-angle pose vector (default=False)
    pose2rot: bool, optional
        Forwarded to `lbs` and to the dynamic face-contour landmark
        lookup. (default=True)
    return_joint_transformation: bool, optional
        Also return the joint transformations produced by `lbs`.
        (default=False)
    return_vertex_transformation: bool, optional
        Also return the vertex transformations produced by `lbs`.
        (default=False)

    Returns
    -------
    output: SMPLXOutput
        Posed vertices and joints (joints include the facial landmarks),
        plus the inputs that were used.
    '''
    # Any input that is not given falls back to the module's member
    # variable of the same name.
    global_orient = (global_orient
                     if global_orient is not None else self.global_orient)
    body_pose = body_pose if body_pose is not None else self.body_pose
    betas = betas if betas is not None else self.betas

    left_hand_pose = (left_hand_pose if left_hand_pose is not None else
                      self.left_hand_pose)
    right_hand_pose = (right_hand_pose if right_hand_pose is not None else
                       self.right_hand_pose)
    jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
    leye_pose = leye_pose if leye_pose is not None else self.leye_pose
    reye_pose = reye_pose if reye_pose is not None else self.reye_pose
    expression = expression if expression is not None else self.expression

    # Translate only when an explicit or stored translation exists.
    apply_trans = transl is not None or hasattr(self, 'transl')
    if transl is None and hasattr(self, 'transl'):
        transl = self.transl

    # Project PCA hand coefficients to per-joint rotations.
    if self.use_pca:
        left_hand_pose = torch.einsum(
            'bi,ij->bj', [left_hand_pose, self.left_hand_components])
        right_hand_pose = torch.einsum(
            'bi,ij->bj', [right_hand_pose, self.right_hand_components])

    full_pose = torch.cat([
        global_orient, body_pose, jaw_pose, leye_pose, reye_pose,
        left_hand_pose, right_hand_pose
    ],
                          dim=1)

    # Add the mean pose; per create_mean_pose only the hand entries can be
    # non-zero.
    full_pose += self.pose_mean

    batch_size = max(betas.shape[0], global_orient.shape[0],
                     body_pose.shape[0])
    # Broadcast a single shape vector over the batch (expand assumes betas
    # has a leading dimension of 1 in that case).
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)
    shape_components = torch.cat([betas, expression], dim=-1)

    # Identity and expression blend shapes are applied together.
    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

    joint_transformation = None
    vertex_transformation = None
    if return_joint_transformation or return_vertex_transformation:
        vertices, joints, joint_transformation, vertex_transformation = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
            return_transformation=True)
    else:
        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

    # Static landmarks. Both tensors use the per-call batch size: the
    # inputs may have a different batch than the one the module was
    # constructed with (self.batch_size), and the two must agree with the
    # vertices for vertices2landmarks.
    lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
        batch_size, -1).contiguous()
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
        batch_size, 1, 1)
    if self.use_face_contour:
        # The pose format must match the one used for lbs above.
        lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=pose2rot,
        )
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([
            lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
        ], 1)

    landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                   lmk_faces_idx, lmk_bary_coords)

    # Append the extra vertex-based joints, then the facial landmarks.
    joints = self.vertex_joint_selector(vertices, joints)
    joints = torch.cat([joints, landmarks], dim=1)

    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if apply_trans:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    output = SMPLXOutput(vertices=vertices if return_verts else None,
                         joints=joints,
                         betas=betas,
                         expression=expression,
                         global_orient=global_orient,
                         body_pose=body_pose,
                         left_hand_pose=left_hand_pose,
                         right_hand_pose=right_hand_pose,
                         jaw_pose=jaw_pose,
                         full_pose=full_pose if return_full_pose else None,
                         joint_transformation=joint_transformation
                         if return_joint_transformation else None,
                         vertex_transformation=vertex_transformation
                         if return_vertex_transformation else None)
    return output
|
|
|
|
class SMPLXLayer(SMPLX):
    ''' SMPL-X as a stateless layer.

    Every ``create_*`` flag is forced to False, so no pose, shape,
    expression or translation member variables are created. All inputs are
    passed to :meth:`forward`, where missing ones default to identity
    rotations or zeros. Poses are rotation matrices (``lbs`` is called with
    ``pose2rot=False``).
    '''

    def __init__(self, *args, **kwargs) -> None:
        # Disable creation of every member variable; the caller supplies
        # all inputs at forward time.
        super(SMPLXLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            create_betas=False,
            create_expression=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                left_hand_pose: Optional[Tensor] = None,
                right_hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                expression: Optional[Tensor] = None,
                jaw_pose: Optional[Tensor] = None,
                leye_pose: Optional[Tensor] = None,
                reye_pose: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> SMPLXOutput:
        '''
        Forward pass for the SMPLX model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. It is expected to be in rotation matrix
            format. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients.
            For example, it can used if expression parameters
            `expression` are predicted from some external model.
        body_pose: torch.tensor, optional, shape BxJx3x3
            If given, ignore the member variable `body_pose` and use it
            instead. For example, it can used if someone predicts the
            pose of the body joints are predicted from some external model.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the left hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            If given, contains the pose of the right hand.
            It should be a tensor that contains joint rotations in
            rotation matrix format. (default=None)
        jaw_pose: torch.tensor, optional, shape Bx3x3
            Jaw pose, as a rotation matrix.
        leye_pose, reye_pose: torch.tensor, optional, shape Bx3x3
            Left / right eye poses, as rotation matrices.
        transl: torch.tensor, optional, shape Bx3
            Translation vector of the body.
            For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full pose vector (default=False)

        Returns
        -------
        output: ModelOutput
            A data class that contains the posed vertices and joints
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every input the caller supplied
        # (including the eye poses) so that the defaults created below
        # agree with it.
        model_vars = [
            betas, global_orient, body_pose, transl, expression,
            left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose
        ]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))

        def identity_rotations(num_joints):
            # (batch_size, num_joints, 3, 3) stack of identity matrices.
            return torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, num_joints, -1,
                                   -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if jaw_pose is None:
            jaw_pose = identity_rotations(1)
        if leye_pose is None:
            leye_pose = identity_rotations(1)
        if reye_pose is None:
            reye_pose = identity_rotations(1)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Stack all rotations in kinematic order:
        # global, body, jaw, left eye, right eye, left hand, right hand.
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            jaw_pose.reshape(-1, 1, 3, 3),
            leye_pose.reshape(-1, 1, 3, 3),
            reye_pose.reshape(-1, 1, 3, 3),
            left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
        ],
                              dim=1)
        shape_components = torch.cat([betas, expression], dim=-1)

        # Identity and expression blend shapes are applied together.
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
            batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1),
                dyn_lmk_bary_coords
            ], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        # Append the extra vertex-based joints, then the facial landmarks.
        joints = self.vertex_joint_selector(vertices, joints)
        joints = torch.cat([joints, landmarks], dim=1)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        # transl always exists here (zeros by default).
        if transl is not None:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLXOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             expression=expression,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=left_hand_pose,
                             right_hand_pose=right_hand_pose,
                             jaw_pose=jaw_pose,
                             transl=transl,
                             full_pose=full_pose if return_full_pose else None)
        return output
|
|
|
|
class MANO(SMPL):
    ''' MANO hand model built on top of the SMPL implementation.

    The kinematic tree has ``NUM_BODY_JOINTS`` (= 1, the global orientation)
    plus ``NUM_HAND_JOINTS`` (= 15) articulated hand joints.
    '''

    NUM_BODY_JOINTS = 1
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS

    def __init__(self,
                 model_path: str,
                 is_rhand: bool = True,
                 data_struct: Optional[Struct] = None,
                 create_hand_pose: bool = True,
                 hand_pose: Optional[Tensor] = None,
                 use_pca: bool = True,
                 num_pca_comps: int = 6,
                 flat_hand_mean: bool = False,
                 batch_size: int = 1,
                 dtype=torch.float32,
                 vertex_ids=None,
                 use_compressed: bool = True,
                 ext: str = 'pkl',
                 **kwargs) -> None:
        ''' MANO model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        is_rhand: bool, optional
            Whether to build a right-hand model. When `model_path` is a
            folder it selects ``MANO_RIGHT.<ext>`` vs ``MANO_LEFT.<ext>``;
            when `model_path` is a file (and `data_struct` is None) the
            handedness is instead inferred from whether the file name
            contains ``'RIGHT'``. (default = True)
        data_struct: Struct
            A struct object. If given, then the parameters of the model are
            read from the object. Otherwise, the model tries to read the
            parameters from the given `model_path`. (default = None)
        create_hand_pose: bool, optional
            Flag for creating a member variable for the hand pose.
            (default = True)
        hand_pose: torch.tensor, optional, BxP
            The default value for the hand pose member variable.
            (default = None)
        use_pca: bool, optional
            Whether the hand pose holds PCA coefficients (projected through
            the PCA hand components in `forward`) rather than axis-angle
            joint rotations. Forced to False when `num_pca_comps` is 45.
            (default = True)
        num_pca_comps: int, optional
            The number of PCA components to use for the hand.
            (default = 6)
        flat_hand_mean: bool, optional
            If True, the hand mean pose is all zeros; otherwise the
            ``hands_mean`` stored in the model data is used.
            (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        dtype: torch.dtype, optional
            The data type for the created variables
        vertex_ids: dict, optional
            A dictionary containing the indices of the extra vertices that
            will be selected (default: ``VERTEX_IDS['smplh']``)
        use_compressed: bool, optional
            Forwarded unchanged to the SMPL constructor.
        ext: str, optional
            Extension of the model file, ``'pkl'`` or ``'npz'``.
            (default = 'pkl')
        '''
        self.num_pca_comps = num_pca_comps
        self.is_rhand = is_rhand

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = 'MANO_{}.{ext}'.format(
                    'RIGHT' if is_rhand else 'LEFT', ext=ext)
                mano_path = os.path.join(model_path, model_fn)
            else:
                mano_path = model_path
                # Infer handedness from the file name (overrides is_rhand).
                self.is_rhand = 'RIGHT' in os.path.basename(model_path)
            assert osp.exists(mano_path), 'Path {} does not exist!'.format(
                mano_path)

            # NOTE: pickle / allow_pickle can execute arbitrary code; only
            # load model files from trusted sources.
            if ext == 'pkl':
                with open(mano_path, 'rb') as mano_file:
                    model_data = pickle.load(mano_file, encoding='latin1')
            elif ext == 'npz':
                model_data = np.load(mano_path, allow_pickle=True)
            else:
                raise ValueError('Unknown extension: {}'.format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS['smplh']

        super(MANO, self).__init__(model_path=model_path,
                                   data_struct=data_struct,
                                   batch_size=batch_size,
                                   vertex_ids=vertex_ids,
                                   use_compressed=use_compressed,
                                   dtype=dtype,
                                   ext=ext,
                                   **kwargs)

        # Replace the extra joints with the MANO-specific vertex set.
        self.vertex_joint_selector.extra_joints_idxs = to_tensor(
            list(VERTEX_IDS['mano'].values()), dtype=torch.long)

        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        # 45 components equal the full 3 * NUM_HAND_JOINTS axis-angle
        # dimension, so the PCA projection is skipped.
        if self.num_pca_comps == 45:
            self.use_pca = False
        self.flat_hand_mean = flat_hand_mean

        hand_components = data_struct.hands_components[:num_pca_comps]
        self.np_hand_components = hand_components

        if self.use_pca:
            self.register_buffer('hand_components',
                                 torch.tensor(hand_components, dtype=dtype))

        if self.flat_hand_mean:
            hand_mean = np.zeros_like(data_struct.hands_mean)
        else:
            hand_mean = data_struct.hands_mean
        self.register_buffer('hand_mean', to_tensor(hand_mean,
                                                    dtype=self.dtype))

        # Width of the hand pose vector. Use the effective self.use_pca
        # (it may have been switched off above), not the raw argument.
        hand_pose_dim = (self.num_pca_comps
                         if self.use_pca else 3 * self.NUM_HAND_JOINTS)
        if create_hand_pose:
            if hand_pose is None:
                default_hand_pose = torch.zeros([batch_size, hand_pose_dim],
                                                dtype=dtype)
            else:
                default_hand_pose = torch.tensor(hand_pose, dtype=dtype)

            hand_pose_param = nn.Parameter(default_hand_pose,
                                           requires_grad=True)
            self.register_parameter('hand_pose', hand_pose_param)

        # Mean pose that forward() adds to the full pose.
        pose_mean = self.create_mean_pose(data_struct,
                                          flat_hand_mean=flat_hand_mean)
        pose_mean_tensor = pose_mean.clone().to(dtype)
        self.register_buffer('pose_mean', pose_mean_tensor)

    def name(self) -> str:
        return 'MANO'

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        ''' Return [zeros(3) for the global orientation, hand_mean]. '''
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0)
        return pose_mean

    def extra_repr(self):
        msg = [super(MANO, self).extra_repr()]
        if self.use_pca:
            msg.append(f'Number of PCA components: {self.num_pca_comps}')
        msg.append(f'Flat hand mean: {self.flat_hand_mean}')
        return '\n'.join(msg)

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> MANOOutput:
        ''' Forward pass for the MANO model

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients; the member variable is used if not given.
        global_orient: torch.tensor, optional, shape Bx3
            Global rotation in axis-angle format; the member variable is
            used if not given.
        hand_pose: torch.tensor, optional, shape BxP
            PCA coefficients when PCA is enabled, otherwise axis-angle
            joint rotations; the member variable is used if not given.
        transl: torch.tensor, optional, shape Bx3
            Translation added to joints and vertices; the member variable
            (if any) is used if not given.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the full axis-angle pose vector. (default=False)

        Returns
        -------
        output: MANOOutput
        '''
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        betas = betas if betas is not None else self.betas
        hand_pose = (hand_pose if hand_pose is not None else self.hand_pose)

        # Translate only when an explicit or stored translation exists.
        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl

        # Project PCA coefficients to per-joint axis-angle rotations.
        if self.use_pca:
            hand_pose = torch.einsum('bi,ij->bj',
                                     [hand_pose, self.hand_components])

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=True,
        )

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints = joints + transl.unsqueeze(dim=1)
            vertices = vertices + transl.unsqueeze(dim=1)

        # return_verts only controls the vertices; joints are always
        # returned, as in the SMPL-X forward passes.
        output = MANOOutput(vertices=vertices if return_verts else None,
                            joints=joints,
                            betas=betas,
                            global_orient=global_orient,
                            hand_pose=hand_pose,
                            full_pose=full_pose if return_full_pose else None)

        return output
|
|
|
|
class MANOLayer(MANO):
    ''' MANO as a stateless layer.

    No pose, shape or translation member variables are created. Every input
    is passed to :meth:`forward`, where missing ones default to identity
    rotations or zeros. Poses are rotation matrices (``lbs`` is called with
    ``pose2rot=False``).
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' MANO as a layer model constructor
        '''
        super(MANOLayer, self).__init__(create_global_orient=False,
                                        create_hand_pose=False,
                                        create_betas=False,
                                        create_transl=False,
                                        *args,
                                        **kwargs)

    def name(self) -> str:
        return 'MANO'

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> MANOOutput:
        ''' Forward pass for the MANO layer

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients; zeros if not given.
        global_orient: torch.tensor, optional
            Global rotation as a rotation matrix, laid out to concatenate
            with `hand_pose` along dim 1 (the default is Bx1x3x3);
            identity if not given.
        hand_pose: torch.tensor, optional
            Hand joint rotations as rotation matrices (the default is
            BxNUM_HAND_JOINTSx3x3); identity if not given.
        transl: torch.tensor, optional, shape Bx3
            Translation added to joints and vertices; zeros if not given.
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the concatenated rotation-matrix pose. (default=False)

        Returns
        -------
        output: MANOOutput
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every input the caller supplied, so the
        # defaults below match it even when global_orient is omitted.
        batch_size = 1
        for var in (betas, global_orient, hand_pose, transl):
            if var is not None:
                batch_size = max(batch_size, len(var))

        if global_orient is None:
            global_orient = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
        if hand_pose is None:
            hand_pose = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, self.NUM_HAND_JOINTS, -1,
                                   -1).contiguous()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=False)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # transl always exists here (zeros by default).
        if transl is not None:
            joints = joints + transl.unsqueeze(dim=1)
            vertices = vertices + transl.unsqueeze(dim=1)

        # return_verts only controls the vertices; joints are always
        # returned, as in the SMPL-X forward passes.
        output = MANOOutput(vertices=vertices if return_verts else None,
                            joints=joints,
                            betas=betas,
                            global_orient=global_orient,
                            hand_pose=hand_pose,
                            full_pose=full_pose if return_full_pose else None)

        return output
|
|
|
|
| class FLAME(SMPL): |
| NUM_JOINTS = 5 |
| SHAPE_SPACE_DIM = 300 |
| EXPRESSION_SPACE_DIM = 100 |
| NECK_IDX = 0 |
|
|
| def __init__(self, |
| model_path: str, |
| data_struct=None, |
| num_expression_coeffs=10, |
| create_expression: bool = True, |
| expression: Optional[Tensor] = None, |
| create_neck_pose: bool = True, |
| neck_pose: Optional[Tensor] = None, |
| create_jaw_pose: bool = True, |
| jaw_pose: Optional[Tensor] = None, |
| create_leye_pose: bool = True, |
| leye_pose: Optional[Tensor] = None, |
| create_reye_pose=True, |
| reye_pose: Optional[Tensor] = None, |
| use_face_contour=False, |
| batch_size: int = 1, |
| gender: str = 'neutral', |
| dtype: torch.dtype = torch.float32, |
| ext='pkl', |
| **kwargs) -> None: |
| ''' FLAME model constructor |
| |
| Parameters |
| ---------- |
| model_path: str |
| The path to the folder or to the file where the model |
| parameters are stored |
| num_expression_coeffs: int, optional |
| Number of expression components to use |
| (default = 10). |
| create_expression: bool, optional |
| Flag for creating a member variable for the expression space |
| (default = True). |
| expression: torch.tensor, optional, Bx10 |
| The default value for the expression member variable. |
| (default = None) |
| create_neck_pose: bool, optional |
| Flag for creating a member variable for the neck pose. |
| (default = False) |
| neck_pose: torch.tensor, optional, Bx3 |
| The default value for the neck pose variable. |
| (default = None) |
| create_jaw_pose: bool, optional |
| Flag for creating a member variable for the jaw pose. |
| (default = False) |
| jaw_pose: torch.tensor, optional, Bx3 |
| The default value for the jaw pose variable. |
| (default = None) |
| create_leye_pose: bool, optional |
| Flag for creating a member variable for the left eye pose. |
| (default = False) |
| leye_pose: torch.tensor, optional, Bx10 |
| The default value for the left eye pose variable. |
| (default = None) |
| create_reye_pose: bool, optional |
| Flag for creating a member variable for the right eye pose. |
| (default = False) |
| reye_pose: torch.tensor, optional, Bx10 |
| The default value for the right eye pose variable. |
| (default = None) |
| use_face_contour: bool, optional |
| Whether to compute the keypoints that form the facial contour |
| batch_size: int, optional |
| The batch size used for creating the member variables |
| gender: str, optional |
| Which gender to load |
| dtype: torch.dtype |
| The data type for the created variables |
| ''' |
| model_fn = f'FLAME_{gender.upper()}.{ext}' |
| flame_path = os.path.join(model_path, model_fn) |
| assert osp.exists(flame_path), 'Path {} does not exist!'.format( |
| flame_path) |
| if ext == 'npz': |
| file_data = np.load(flame_path, allow_pickle=True) |
| elif ext == 'pkl': |
| with open(flame_path, 'rb') as smpl_file: |
| file_data = pickle.load(smpl_file, encoding='latin1') |
| else: |
| raise ValueError('Unknown extension: {}'.format(ext)) |
| data_struct = Struct(**file_data) |
|
|
| super(FLAME, self).__init__(model_path=model_path, |
| data_struct=data_struct, |
| dtype=dtype, |
| batch_size=batch_size, |
| gender=gender, |
| ext=ext, |
| **kwargs) |
|
|
| self.use_face_contour = use_face_contour |
|
|
| self.vertex_joint_selector.extra_joints_idxs = to_tensor( |
| [], dtype=torch.long) |
|
|
| if create_neck_pose: |
| if neck_pose is None: |
| default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| else: |
| default_neck_pose = torch.tensor(neck_pose, dtype=dtype) |
| neck_pose_param = nn.Parameter(default_neck_pose, |
| requires_grad=True) |
| self.register_parameter('neck_pose', neck_pose_param) |
|
|
| if create_jaw_pose: |
| if jaw_pose is None: |
| default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| else: |
| default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) |
| jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) |
| self.register_parameter('jaw_pose', jaw_pose_param) |
|
|
| if create_leye_pose: |
| if leye_pose is None: |
| default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| else: |
| default_leye_pose = torch.tensor(leye_pose, dtype=dtype) |
| leye_pose_param = nn.Parameter(default_leye_pose, |
| requires_grad=True) |
| self.register_parameter('leye_pose', leye_pose_param) |
|
|
| if create_reye_pose: |
| if reye_pose is None: |
| default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| else: |
| default_reye_pose = torch.tensor(reye_pose, dtype=dtype) |
| reye_pose_param = nn.Parameter(default_reye_pose, |
| requires_grad=True) |
| self.register_parameter('reye_pose', reye_pose_param) |
|
|
| shapedirs = data_struct.shapedirs |
| if len(shapedirs.shape) < 3: |
| shapedirs = shapedirs[:, :, None] |
| if (shapedirs.shape[-1] < |
| self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM): |
| |
| |
| expr_start_idx = 10 |
| expr_end_idx = 20 |
| num_expression_coeffs = min(num_expression_coeffs, 10) |
| else: |
| expr_start_idx = self.SHAPE_SPACE_DIM |
| expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs |
| num_expression_coeffs = min(num_expression_coeffs, |
| self.EXPRESSION_SPACE_DIM) |
|
|
| self._num_expression_coeffs = num_expression_coeffs |
|
|
| expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] |
| self.register_buffer('expr_dirs', |
| to_tensor(to_np(expr_dirs), dtype=dtype)) |
|
|
| if create_expression: |
| if expression is None: |
| default_expression = torch.zeros( |
| [batch_size, self.num_expression_coeffs], dtype=dtype) |
| else: |
| default_expression = torch.tensor(expression, dtype=dtype) |
| expression_param = nn.Parameter(default_expression, |
| requires_grad=True) |
| self.register_parameter('expression', expression_param) |
|
|
| |
| |
| landmark_bcoord_filename = osp.join(model_path, |
| 'flame_static_embedding.pkl') |
|
|
| with open(landmark_bcoord_filename, 'rb') as fp: |
| landmarks_data = pickle.load(fp, encoding='latin1') |
|
|
| lmk_faces_idx = landmarks_data['lmk_face_idx'].astype(np.int64) |
| self.register_buffer('lmk_faces_idx', |
| torch.tensor(lmk_faces_idx, dtype=torch.long)) |
| lmk_bary_coords = landmarks_data['lmk_b_coords'] |
| self.register_buffer('lmk_bary_coords', |
| torch.tensor(lmk_bary_coords, dtype=dtype)) |
| if self.use_face_contour: |
| face_contour_path = os.path.join(model_path, |
| 'flame_dynamic_embedding.npy') |
| contour_embeddings = np.load(face_contour_path, |
| allow_pickle=True, |
| encoding='latin1')[()] |
|
|
| dynamic_lmk_faces_idx = np.array( |
| contour_embeddings['lmk_face_idx'], dtype=np.int64) |
| dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, |
| dtype=torch.long) |
| self.register_buffer('dynamic_lmk_faces_idx', |
| dynamic_lmk_faces_idx) |
|
|
| dynamic_lmk_b_coords = torch.tensor( |
| contour_embeddings['lmk_b_coords'], dtype=dtype) |
| self.register_buffer('dynamic_lmk_bary_coords', |
| dynamic_lmk_b_coords) |
|
|
| neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) |
| self.register_buffer( |
| 'neck_kin_chain', torch.tensor(neck_kin_chain, |
| dtype=torch.long)) |
|
|
| @property |
| def num_expression_coeffs(self): |
| return self._num_expression_coeffs |
|
|
| def name(self) -> str: |
| return 'FLAME' |
|
|
| def extra_repr(self): |
| msg = [ |
| super(FLAME, self).extra_repr(), |
| f'Number of Expression Coefficients: {self.num_expression_coeffs}', |
| f'Use face contour: {self.use_face_contour}', |
| ] |
| return '\n'.join(msg) |
|
|
| def forward(self, |
| betas: Optional[Tensor] = None, |
| global_orient: Optional[Tensor] = None, |
| neck_pose: Optional[Tensor] = None, |
| transl: Optional[Tensor] = None, |
| expression: Optional[Tensor] = None, |
| jaw_pose: Optional[Tensor] = None, |
| leye_pose: Optional[Tensor] = None, |
| reye_pose: Optional[Tensor] = None, |
| return_verts: bool = True, |
| return_full_pose: bool = False, |
| pose2rot: bool = True, |
| **kwargs) -> FLAMEOutput: |
| ''' |
| Forward pass for the SMPLX model |
| |
| Parameters |
| ---------- |
| global_orient: torch.tensor, optional, shape Bx3 |
| If given, ignore the member variable and use it as the global |
| rotation of the body. Useful if someone wishes to predicts this |
| with an external model. (default=None) |
| betas: torch.tensor, optional, shape Bx10 |
| If given, ignore the member variable `betas` and use it |
| instead. For example, it can used if shape parameters |
| `betas` are predicted from some external model. |
| (default=None) |
| expression: torch.tensor, optional, shape Bx10 |
| If given, ignore the member variable `expression` and use it |
| instead. For example, it can used if expression parameters |
| `expression` are predicted from some external model. |
| jaw_pose: torch.tensor, optional, shape Bx3 |
| If given, ignore the member variable `jaw_pose` and |
| use this instead. It should either joint rotations in |
| axis-angle format. |
| jaw_pose: torch.tensor, optional, shape Bx3 |
| If given, ignore the member variable `jaw_pose` and |
| use this instead. It should either joint rotations in |
| axis-angle format. |
| transl: torch.tensor, optional, shape Bx3 |
| If given, ignore the member variable `transl` and use it |
| instead. For example, it can used if the translation |
| `transl` is predicted from some external model. |
| (default=None) |
| return_verts: bool, optional |
| Return the vertices. (default=True) |
| return_full_pose: bool, optional |
| Returns the full axis-angle pose vector (default=False) |
| |
| Returns |
| ------- |
| output: ModelOutput |
| A named tuple of type `ModelOutput` |
| ''' |
|
|
| |
| |
| global_orient = (global_orient |
| if global_orient is not None else self.global_orient) |
| jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose |
| neck_pose = neck_pose if neck_pose is not None else self.neck_pose |
|
|
| leye_pose = leye_pose if leye_pose is not None else self.leye_pose |
| reye_pose = reye_pose if reye_pose is not None else self.reye_pose |
|
|
| betas = betas if betas is not None else self.betas |
| expression = expression if expression is not None else self.expression |
|
|
| apply_trans = transl is not None or hasattr(self, 'transl') |
| if transl is None: |
| if hasattr(self, 'transl'): |
| transl = self.transl |
|
|
| full_pose = torch.cat( |
| [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) |
|
|
| batch_size = max(betas.shape[0], global_orient.shape[0], |
| jaw_pose.shape[0]) |
| |
| scale = int(batch_size / betas.shape[0]) |
| if scale > 1: |
| betas = betas.expand(scale, -1) |
| shape_components = torch.cat([betas, expression], dim=-1) |
| shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) |
|
|
| vertices, joints = lbs( |
| shape_components, |
| full_pose, |
| self.v_template, |
| shapedirs, |
| self.posedirs, |
| self.J_regressor, |
| self.parents, |
| self.lbs_weights, |
| pose2rot=pose2rot, |
| ) |
|
|
| lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( |
| batch_size, -1).contiguous() |
| lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( |
| self.batch_size, 1, 1) |
| if self.use_face_contour: |
| lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( |
| vertices, |
| full_pose, |
| self.dynamic_lmk_faces_idx, |
| self.dynamic_lmk_bary_coords, |
| self.neck_kin_chain, |
| pose2rot=True, |
| ) |
| dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords |
| lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) |
| lmk_bary_coords = torch.cat([ |
| lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords |
| ], 1) |
|
|
| landmarks = vertices2landmarks(vertices, self.faces_tensor, |
| lmk_faces_idx, lmk_bary_coords) |
|
|
| |
| joints = self.vertex_joint_selector(vertices, joints) |
| |
| joints = torch.cat([joints, landmarks], dim=1) |
|
|
| |
| if self.joint_mapper is not None: |
| joints = self.joint_mapper(joints=joints, vertices=vertices) |
|
|
| if apply_trans: |
| joints += transl.unsqueeze(dim=1) |
| vertices += transl.unsqueeze(dim=1) |
|
|
| output = FLAMEOutput(vertices=vertices if return_verts else None, |
| joints=joints, |
| betas=betas, |
| expression=expression, |
| global_orient=global_orient, |
| neck_pose=neck_pose, |
| jaw_pose=jaw_pose, |
| full_pose=full_pose if return_full_pose else None) |
| return output |
|
|
|
|
class FLAMELayer(FLAME):
    ''' FLAME model used as a stateless layer.

    Unlike FLAME, this class registers no betas, expression, global_orient,
    neck_pose, jaw_pose, leye_pose or reye_pose parameters. Every input is
    passed to `forward`, and pose inputs are rotation matrices.
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' FLAME as a layer model constructor

        All positional and keyword arguments are forwarded to FLAME with the
        creation of the shape, expression and pose member parameters
        disabled.
        '''
        super(FLAMELayer, self).__init__(create_betas=False,
                                         create_expression=False,
                                         create_global_orient=False,
                                         create_neck_pose=False,
                                         create_jaw_pose=False,
                                         create_leye_pose=False,
                                         create_reye_pose=False,
                                         *args,
                                         **kwargs)

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                neck_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                expression: Optional[Tensor] = None,
                jaw_pose: Optional[Tensor] = None,
                leye_pose: Optional[Tensor] = None,
                reye_pose: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> FLAMEOutput:
        ''' Forward pass for the FLAME layer

        The batch size is the largest leading dimension among the inputs that
        were provided (1 if none were). Inputs left as None are filled with
        neutral defaults of that batch size: identity rotations, zero
        shape/expression coefficients and zero translation.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            Shape coefficients. A batch of one is broadcast to the full
            batch. (default=None, i.e. zeros)
        global_orient: torch.tensor, optional, shape Bx1x3x3
            Global rotation of the head as rotation matrices. Use the same
            layout as the identity defaults built for omitted poses.
            (default=None, i.e. identity)
        neck_pose: torch.tensor, optional, shape Bx1x3x3
            Neck rotation matrix. (default=None, i.e. identity)
        transl: torch.tensor, optional, shape Bx3
            Translation added to vertices and joints.
            (default=None, i.e. zeros)
        expression: torch.tensor, optional, shape BxN_e
            Expression coefficients. A batch of one is broadcast to the full
            batch. (default=None, i.e. zeros)
        jaw_pose: torch.tensor, optional, shape Bx1x3x3
            Jaw rotation matrix. (default=None, i.e. identity)
        leye_pose: torch.tensor, optional, shape Bx1x3x3
            Left-eye rotation matrix. (default=None, i.e. identity)
        reye_pose: torch.tensor, optional, shape Bx1x3x3
            Right-eye rotation matrix. (default=None, i.e. identity)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the rotation matrices of all five poses concatenated
            along dim 1. (default=False)
        pose2rot: bool, optional
            Accepted for interface compatibility and ignored: the layer
            always treats poses as rotation matrices.

        Returns
        -------
        output: FLAMEOutput
            Vertices, joints (with landmarks appended), the shape/expression
            coefficients and the poses that were used.
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every input that was actually provided,
        # so that e.g. batched betas alone produce batched default poses.
        provided = [betas, global_orient, neck_pose, transl, expression,
                    jaw_pose, leye_pose, reye_pose]
        batch_size = 1
        for var in provided:
            if var is not None:
                batch_size = max(batch_size, var.shape[0])

        def identity_rotations():
            # A fresh contiguous (batch_size, 1, 3, 3) stack of identities.
            return torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations()
        if neck_pose is None:
            neck_pose = identity_rotations()
        if jaw_pose is None:
            jaw_pose = identity_rotations()
        if leye_pose is None:
            leye_pose = identity_rotations()
        if reye_pose is None:
            reye_pose = identity_rotations()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        elif betas.shape[0] == 1 and batch_size > 1:
            betas = betas.expand(batch_size, -1)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        elif expression.shape[0] == 1 and batch_size > 1:
            expression = expression.expand(batch_size, -1)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat(
            [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

        # Shape and expression blend shapes are applied in a single pass.
        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        # Static landmarks: replicate the embedding once per sample of the
        # *current* inputs; self.batch_size need not match them.
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
            batch_size, 1, 1)
        if self.use_face_contour:
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = (
                find_dynamic_lmk_idx_and_bcoords(
                    vertices,
                    full_pose,
                    self.dynamic_lmk_faces_idx,
                    self.dynamic_lmk_bary_coords,
                    self.neck_kin_chain,
                    pose2rot=False,
                ))
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat(
                [lmk_bary_coords, dyn_lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        # Let the vertex joint selector augment the regressed joints, then
        # append the landmarks as extra joints.
        joints = self.vertex_joint_selector(vertices, joints)
        joints = torch.cat([joints, landmarks], dim=1)

        # Optionally remap joints to another convention.
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = FLAMEOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             expression=expression,
                             global_orient=global_orient,
                             neck_pose=neck_pose,
                             jaw_pose=jaw_pose,
                             full_pose=full_pose if return_full_pose else None)
        return output
|
|
|
|
def build_layer(
    model_path: str,
    model_type: str = 'smpl',
    **kwargs
) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]:
    ''' Instantiate the layer variant of a body model from a path

    Parameters
    ----------
    model_path: str
        Either the path to a single model file, or a folder whose
        subfolders hold the different model types, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO RIGHT
            |-- MANO LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL

    model_type: str, optional
        Used only when `model_path` is a folder: names the subfolder to load
        from. For a file path, the type is taken from the file-name prefix
        before the first underscore.
    **kwargs: dict
        Keyword arguments forwarded to the layer constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch layer module that implements the requested body model
    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    '''
    if not osp.isdir(model_path):
        # e.g. '.../SMPLX_NEUTRAL.npz' -> 'smplx'
        model_type = osp.basename(model_path).split('_')[0].lower()
    else:
        model_path = os.path.join(model_path, model_type)

    kind = model_type.lower()
    if kind in ('smpl', 'smplh', 'smplx'):
        layer_cls = {
            'smpl': SMPLLayer,
            'smplh': SMPLHLayer,
            'smplx': SMPLXLayer,
        }[kind]
    elif 'mano' in kind:
        layer_cls = MANOLayer
    elif 'flame' in kind:
        layer_cls = FLAMELayer
    else:
        raise ValueError(f'Unknown model type {model_type}, exiting!')
    return layer_cls(model_path, **kwargs)
|
|
|
|
def create(model_path: str,
           model_type: str = 'smpl',
           **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
    ''' Instantiate a body model (with its own parameters) from a path

    Parameters
    ----------
    model_path: str
        Either the path to a single model file, or a folder whose
        subfolders hold the different model types, i.e.:
        model_path:
        |
        |-- smpl
            |-- SMPL_FEMALE
            |-- SMPL_NEUTRAL
            |-- SMPL_MALE
        |-- smplh
            |-- SMPLH_FEMALE
            |-- SMPLH_MALE
        |-- smplx
            |-- SMPLX_FEMALE
            |-- SMPLX_NEUTRAL
            |-- SMPLX_MALE
        |-- mano
            |-- MANO RIGHT
            |-- MANO LEFT
        |-- flame
            |-- FLAME_FEMALE
            |-- FLAME_MALE
            |-- FLAME_NEUTRAL

    model_type: str, optional
        Used only when `model_path` is a folder: names the subfolder to load
        from. For a file path, the type is taken from the file-name prefix
        before the first underscore.
    **kwargs: dict
        Keyword arguments forwarded to the model constructor

    Returns
    -------
    body_model: nn.Module
        The PyTorch module that implements the corresponding body model
    Raises
    ------
    ValueError: In case the model type is not one of SMPL, SMPLH,
        SMPLX, MANO or FLAME
    '''
    if not osp.isdir(model_path):
        # e.g. '.../SMPLX_NEUTRAL.npz' -> 'smplx'
        model_type = osp.basename(model_path).split('_')[0].lower()
    else:
        model_path = os.path.join(model_path, model_type)

    kind = model_type.lower()
    if kind in ('smpl', 'smplh', 'smplx'):
        model_cls = {
            'smpl': SMPL,
            'smplh': SMPLH,
            'smplx': SMPLX,
        }[kind]
    elif 'mano' in kind:
        model_cls = MANO
    elif 'flame' in kind:
        model_cls = FLAME
    else:
        raise ValueError(f'Unknown model type {model_type}, exiting!')
    return model_cls(model_path, **kwargs)
|
|