| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Optional, Dict, Union |
| | import os |
| | import os.path as osp |
| | import pickle |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from collections import namedtuple |
| |
|
| | import logging |
| |
|
| | logging.getLogger("smplx").setLevel(logging.ERROR) |
| |
|
| | from .lbs import (lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords) |
| |
|
| | from .vertex_ids import vertex_ids as VERTEX_IDS |
| | from .utils import (Struct, to_np, to_tensor, Tensor, Array, SMPLOutput, |
| | SMPLHOutput, SMPLXOutput, MANOOutput, FLAMEOutput, |
| | find_joint_kin_chain) |
| | from .vertex_joint_selector import VertexJointSelector |
| |
|
# Lightweight result record shared by the body models. Every field is
# optional: the defaults installed below make a bare ``ModelOutput()`` valid,
# with all entries set to None.
ModelOutput = namedtuple(
    'ModelOutput',
    [
        'vertices',
        'joints',
        'full_pose',
        'betas',
        'global_orient',
        'body_pose',
        'expression',
        'left_hand_pose',
        'right_hand_pose',
        'jaw_pose',
    ],
)
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)
| |
|
| |
|
class SMPL(nn.Module):
    ''' SMPL body model wrapped as a ``torch.nn.Module``.

    The constructor reads the model data (from ``data_struct`` or from a
    pickle file), registers the template, blend shapes, joint regressor,
    kinematic parents and skinning weights as buffers and, optionally,
    creates learnable parameters for shape, pose and translation.
    ``forward`` passes them through ``lbs``.
    '''

    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    SHAPE_SPACE_DIM = 300

    def __init__(self,
                 model_path: str,
                 kid_template_path: str = '',
                 data_struct: Optional[Struct] = None,
                 create_betas: bool = True,
                 betas: Optional[Tensor] = None,
                 num_betas: int = 10,
                 create_global_orient: bool = True,
                 global_orient: Optional[Tensor] = None,
                 create_body_pose: bool = True,
                 body_pose: Optional[Tensor] = None,
                 create_transl: bool = True,
                 transl: Optional[Tensor] = None,
                 dtype=torch.float32,
                 batch_size: int = 1,
                 joint_mapper=None,
                 gender: str = 'neutral',
                 age: str = 'adult',
                 vertex_ids: Optional[Dict[str, int]] = None,
                 v_template: Optional[Union[Tensor, Array]] = None,
                 v_personal: Optional[Union[Tensor, Array]] = None,
                 **kwargs) -> None:
        ''' SMPL model constructor

            Parameters
            ----------
            model_path: str
                The path to the folder or to the file where the model
                parameters are stored. When a folder is given, the file
                ``SMPL_<GENDER>.pkl`` inside it is loaded.
            kid_template_path: str, optional
                Path of the template loaded with ``np.load`` when
                ``age == 'kid'``. (default = '')
            data_struct: Struct
                A struct object. If given, then the parameters of the model
                are read from the object. Otherwise, the model tries to read
                the parameters from the given `model_path`. (default = None)
            create_betas: bool, optional
                Flag for creating a member variable for the shape space
                (default = True).
            betas: torch.tensor, optional, Bx10
                The default value for the shape member variable.
                (default = None)
            num_betas: int, optional
                Number of shape components to use. It is capped at 10 when
                the model stores fewer than ``SHAPE_SPACE_DIM`` components
                and at ``SHAPE_SPACE_DIM`` otherwise. (default = 10)
            create_global_orient: bool, optional
                Flag for creating a member variable for the global
                orientation of the body. (default = True)
            global_orient: torch.tensor, optional, Bx3
                The default value for the global orientation variable.
                (default = None)
            create_body_pose: bool, optional
                Flag for creating a member variable for the pose of the
                body. (default = True)
            body_pose: torch.tensor, optional, Bx(Body Joints * 3)
                The default value for the body pose variable.
                (default = None)
            create_transl: bool, optional
                Flag for creating a member variable for the translation
                of the body. (default = True)
            transl: torch.tensor, optional, Bx3
                The default value for the transl variable.
                (default = None)
            dtype: torch.dtype, optional
                The data type for the created variables
            batch_size: int, optional
                The batch size used for creating the member variables
            joint_mapper: object, optional
                An object that re-maps the joints. Useful if one wants to
                re-order the SMPL joints to some other convention
                (e.g. MSCOCO) (default = None)
            gender: str, optional
                Which gender to load
            age: str, optional
                When set to 'kid', the kid template is blended in as one
                extra shape component. (default = 'adult')
            vertex_ids: dict, optional
                A dictionary containing the indices of the extra vertices
                that will be selected. Defaults to ``VERTEX_IDS['smplh']``.
            v_template: tensor or array, optional
                Template vertices used instead of the ones stored in the
                model data. (default = None)
            v_personal: tensor or array, optional
                Offsets that are added to the template vertices.
                (default = None)
            **kwargs:
                Forwarded to ``VertexJointSelector``.
        '''

        self.gender = gender
        self.age = age

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
                smpl_path = os.path.join(model_path, model_fn)
            else:
                smpl_path = model_path
            assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
                smpl_path)

            # NOTE(review): unpickling executes arbitrary code; only load
            # model files that come from a trusted source.
            with open(smpl_path, 'rb') as smpl_file:
                data_struct = Struct(
                    **pickle.load(smpl_file, encoding='latin1'))

        super(SMPL, self).__init__()
        self.batch_size = batch_size
        shapedirs = data_struct.shapedirs
        if shapedirs.shape[-1] < self.SHAPE_SPACE_DIM:
            # The model stores fewer components than the full shape space:
            # never use more than 10 of them.
            num_betas = min(num_betas, 10)
        else:
            num_betas = min(num_betas, self.SHAPE_SPACE_DIM)

        if self.age == 'kid':
            # Center the kid template and append its offset from the adult
            # template as one additional shape direction.
            v_template_smil = np.load(kid_template_path)
            v_template_smil -= np.mean(v_template_smil, axis=0)
            v_template_diff = np.expand_dims(
                v_template_smil - data_struct.v_template, axis=2)
            shapedirs = np.concatenate(
                (shapedirs[:, :, :num_betas], v_template_diff), axis=2)
            num_betas = num_betas + 1

        self._num_betas = num_betas
        shapedirs = shapedirs[:, :, :num_betas]
        self.register_buffer('shapedirs',
                             to_tensor(to_np(shapedirs), dtype=dtype))

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS['smplh']

        self.dtype = dtype

        self.joint_mapper = joint_mapper

        self.vertex_joint_selector = VertexJointSelector(
            vertex_ids=vertex_ids, **kwargs)

        self.faces = data_struct.f
        self.register_buffer(
            'faces_tensor',
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long))

        if create_betas:
            if betas is None:
                default_betas = torch.zeros([batch_size, self.num_betas],
                                            dtype=dtype)
            elif torch.is_tensor(betas):
                default_betas = betas.clone().detach()
            else:
                default_betas = torch.tensor(betas, dtype=dtype)

            self.register_parameter(
                'betas', nn.Parameter(default_betas, requires_grad=True))

        if create_global_orient:
            if global_orient is None:
                default_global_orient = torch.zeros([batch_size, 3],
                                                    dtype=dtype)
            elif torch.is_tensor(global_orient):
                default_global_orient = global_orient.clone().detach()
            else:
                default_global_orient = torch.tensor(global_orient,
                                                     dtype=dtype)

            global_orient = nn.Parameter(default_global_orient,
                                         requires_grad=True)
            self.register_parameter('global_orient', global_orient)

        if create_body_pose:
            if body_pose is None:
                default_body_pose = torch.zeros(
                    [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
            elif torch.is_tensor(body_pose):
                default_body_pose = body_pose.clone().detach()
            else:
                default_body_pose = torch.tensor(body_pose, dtype=dtype)
            self.register_parameter(
                'body_pose',
                nn.Parameter(default_body_pose, requires_grad=True))

        if create_transl:
            if transl is None:
                # nn.Parameter below turns on gradients, so the zeros
                # tensor itself does not need requires_grad.
                default_transl = torch.zeros([batch_size, 3], dtype=dtype)
            elif torch.is_tensor(transl):
                # Copy without going through torch.tensor(tensor), which
                # warns; the cast to `dtype` is kept.
                default_transl = transl.clone().detach().to(dtype=dtype)
            else:
                default_transl = torch.tensor(transl, dtype=dtype)
            self.register_parameter(
                'transl', nn.Parameter(default_transl, requires_grad=True))

        if v_template is None:
            v_template = data_struct.v_template

        if not torch.is_tensor(v_template):
            v_template = to_tensor(to_np(v_template), dtype=dtype)

        if v_personal is not None:
            v_personal = to_tensor(to_np(v_personal), dtype=dtype)
            # Out-of-place add: a tensor handed in by the caller as
            # `v_template` must not be modified behind their back.
            v_template = v_template + v_personal

        self.register_buffer('v_template', v_template)

        j_regressor = to_tensor(to_np(data_struct.J_regressor), dtype=dtype)
        self.register_buffer('J_regressor', j_regressor)

        # Flatten the pose blend shapes so that the last axis of the stored
        # array (the pose basis) becomes the first axis of the buffer.
        num_pose_basis = data_struct.posedirs.shape[-1]
        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs',
                             to_tensor(to_np(posedirs), dtype=dtype))

        # The first row of the kinematic tree table is used as the parent
        # index of every joint; the first joint is marked with -1.
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer('parents', parents)

        self.register_buffer(
            'lbs_weights', to_tensor(to_np(data_struct.weights), dtype=dtype))

    @property
    def num_betas(self):
        ''' Number of shape coefficients used by this model instance. '''
        return self._num_betas

    @property
    def num_expression_coeffs(self):
        ''' This class has no expression space, so the count is always 0. '''
        return 0

    def create_mean_pose(self, data_struct) -> Tensor:
        ''' Placeholder: does nothing here and returns None. '''
        pass

    def name(self) -> str:
        ''' Return the display name of the model. '''
        return 'SMPL'

    @torch.no_grad()
    def reset_params(self, **params_dict) -> None:
        ''' Overwrite the registered parameters in place.

            Every parameter whose name appears in `params_dict` is filled
            with the given value; all the other parameters are set to zero.
        '''
        for param_name, param in self.named_parameters():
            if param_name in params_dict:
                # as_tensor accepts tensors, arrays and lists alike and does
                # not warn when the value already is a tensor.
                param[:] = torch.as_tensor(params_dict[param_name])
            else:
                param.fill_(0)

    def get_num_verts(self) -> int:
        ''' Size of the first dimension of the template vertex buffer. '''
        return self.v_template.shape[0]

    def get_num_faces(self) -> int:
        ''' Size of the first dimension of the stored faces array. '''
        return self.faces.shape[0]

    def extra_repr(self) -> str:
        ''' Summary lines shown in the module's printed representation. '''
        msg = [
            f'Gender: {self.gender.upper()}',
            f'Number of joints: {self.J_regressor.shape[0]}',
            f'Betas: {self.num_betas}',
        ]
        return '\n'.join(msg)

    def forward(self,
                betas: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts=True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLOutput:
        ''' Forward pass for the SMPL model

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3
                If given, ignore the member variable and use it as the
                global rotation of the body. Useful if someone wishes to
                predict this with an external model. (default=None)
            betas: torch.tensor, optional, shape BxN_b
                If given, ignore the member variable `betas` and use it
                instead. For example, it can be used if shape parameters
                `betas` are predicted from some external model.
                (default=None)
            body_pose: torch.tensor, optional, shape Bx(J*3)
                If given, ignore the member variable `body_pose` and use it
                instead. For example, it can be used if the pose of the
                body joints is predicted from some external model.
                It should be a tensor that contains joint rotations in
                axis-angle format. (default=None)
            transl: torch.tensor, optional, shape Bx3
                If given, ignore the member variable `transl` and use it
                instead. For example, it can be used if the translation
                `transl` is predicted from some external model.
                (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose vector (default=False)
            pose2rot: bool, optional
                Forwarded unchanged to ``lbs``. (default=True)

            Returns
            -------
            SMPLOutput
                Holds `joints`, `betas`, `global_orient` and `body_pose`;
                `vertices` is None unless `return_verts` is set and
                `full_pose` is None unless `return_full_pose` is set.
        '''
        # Arguments that were not passed fall back to the member variables.
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas

        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl

        full_pose = torch.cat([global_orient, body_pose], dim=1)

        batch_size = max(betas.shape[0], global_orient.shape[0],
                         body_pose.shape[0])

        if betas.shape[0] != batch_size:
            # Repeat the shape coefficients so that they match the batch
            # size of the pose inputs.
            num_repeats = int(batch_size / betas.shape[0])
            betas = betas.expand(num_repeats, -1)

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=pose2rot)

        joints = self.vertex_joint_selector(vertices, joints)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLOutput(vertices=vertices if return_verts else None,
                            global_orient=global_orient,
                            body_pose=body_pose,
                            joints=joints,
                            betas=betas,
                            full_pose=full_pose if return_full_pose else None)

        return output
| |
|
| |
|
class SMPLLayer(SMPL):
    ''' SMPL variant that creates no learnable pose, shape or translation
        parameters; every input is passed to ``forward`` directly.
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' Build the parent model with all `create_*` flags switched off. '''
        disabled_members = dict(
            create_body_pose=False,
            create_betas=False,
            create_global_orient=False,
            create_transl=False,
        )
        super().__init__(*args, **disabled_members, **kwargs)

    def forward(self,
                betas: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts=True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLOutput:
        ''' Forward pass for the SMPL model

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3x3
                Global rotation of the body, in rotation matrix format.
                An identity rotation is used when omitted. (default=None)
            betas: torch.tensor, optional, shape BxN_b
                Shape parameters; zeros are used when omitted.
                (default=None)
            body_pose: torch.tensor, optional, shape BxJx3x3
                Joint rotations of the body in rotation matrix format;
                identity rotations are used when omitted. (default=None)
            transl: torch.tensor, optional, shape Bx3
                Translation vector of the body; zeros are used when
                omitted. (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose tensor (default=False)
            pose2rot: bool, optional
                Not used by this method: ``lbs`` is always called with
                ``pose2rot=False``.

            Returns
            -------
            SMPLOutput
                Holds `joints`, `betas`, `global_orient` and `body_pose`;
                `vertices` and `full_pose` are None unless requested.
        '''
        # The batch size is the longest of the inputs that were actually
        # given, and 1 when nothing was passed.
        provided = [
            item for item in (betas, global_orient, body_pose, transl)
            if item is not None
        ]
        batch_size = max([1] + [len(item) for item in provided])
        device = self.shapedirs.device
        dtype = self.shapedirs.dtype

        def identity_rotations(num_joints):
            # One 3x3 identity matrix per joint and batch element.
            eye = torch.eye(3, device=device, dtype=dtype)
            tiled = eye.view(1, 1, 3, 3).expand(batch_size, num_joints, -1,
                                                -1)
            return tiled.contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        root_rotation = global_orient.reshape(-1, 1, 3, 3)
        joint_rotations = body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)
        full_pose = torch.cat([root_rotation, joint_rotations], dim=1)

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=False)

        joints = self.vertex_joint_selector(vertices, joints)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if transl is not None:
            offset = transl.unsqueeze(dim=1)
            joints += offset
            vertices += offset

        return SMPLOutput(
            vertices=vertices if return_verts else None,
            global_orient=global_orient,
            body_pose=body_pose,
            joints=joints,
            betas=betas,
            full_pose=full_pose if return_full_pose else None,
        )
| |
|
| |
|
class SMPLH(SMPL):
    ''' SMPL+H: the SMPL body model extended with two articulated hands.

        On top of the SMPL members it registers the hand pose means, the
        hand PCA components (when `use_pca` is set) and, optionally, one
        learnable pose parameter per hand.
    '''

    # The last two SMPL joints are replaced by the hand joints.
    NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS

    def __init__(self,
                 model_path,
                 kid_template_path: str = '',
                 data_struct: Optional[Struct] = None,
                 create_left_hand_pose: bool = True,
                 left_hand_pose: Optional[Tensor] = None,
                 create_right_hand_pose: bool = True,
                 right_hand_pose: Optional[Tensor] = None,
                 use_pca: bool = True,
                 num_pca_comps: int = 6,
                 flat_hand_mean: bool = False,
                 batch_size: int = 1,
                 gender: str = 'neutral',
                 age: str = 'adult',
                 dtype=torch.float32,
                 vertex_ids=None,
                 use_compressed: bool = True,
                 ext: str = 'pkl',
                 **kwargs) -> None:
        ''' SMPLH model constructor

            Parameters
            ----------
            model_path: str
                The path to the folder or to the file where the model
                parameters are stored. When a folder is given, the file
                ``SMPLH_<GENDER>.<ext>`` inside it is loaded.
            kid_template_path: str, optional
                Forwarded to the SMPL constructor. (default = '')
            data_struct: Struct
                A struct object. If given, then the parameters of the model
                are read from the object. Otherwise, the model tries to read
                the parameters from the given `model_path`. (default = None)
            create_left_hand_pose: bool, optional
                Flag for creating a member variable for the pose of the left
                hand. (default = True)
            left_hand_pose: torch.tensor, optional, BxP
                The default value for the left hand pose member variable.
                (default = None)
            create_right_hand_pose: bool, optional
                Flag for creating a member variable for the pose of the
                right hand. (default = True)
            right_hand_pose: torch.tensor, optional, BxP
                The default value for the right hand pose member variable.
                (default = None)
            use_pca: bool, optional
                If True, the hand poses are PCA coefficients that `forward`
                projects through the hand components. (default = True)
            num_pca_comps: int, optional
                The number of PCA components to use for each hand.
                (default = 6)
            flat_hand_mean: bool, optional
                If True, the mean hand poses are replaced with zeros;
                otherwise the means stored in the model data are used.
                (default = False)
            batch_size: int, optional
                The batch size used for creating the member variables
            gender: str, optional
                Which gender to load
            age: str, optional
                Forwarded to the SMPL constructor. (default = 'adult')
            dtype: torch.dtype, optional
                The data type for the created variables
            vertex_ids: dict, optional
                A dictionary containing the indices of the extra vertices
                that will be selected. Defaults to ``VERTEX_IDS['smplh']``.
            use_compressed: bool, optional
                Forwarded to the SMPL constructor as an extra keyword.
            ext: str, optional
                Extension of the model file, 'pkl' or 'npz'.
                (default = 'pkl')

            Raises
            ------
            ValueError
                If a file has to be loaded and `ext` is neither 'pkl' nor
                'npz'.
        '''

        # A plain int, assigned before the parent constructor runs, exactly
        # as the SMPL constructor does with `gender` and `age`.
        self.num_pca_comps = num_pca_comps

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext)
                smplh_path = os.path.join(model_path, model_fn)
            else:
                smplh_path = model_path
            assert osp.exists(smplh_path), 'Path {} does not exist!'.format(
                smplh_path)

            # NOTE(review): both loaders can unpickle arbitrary objects;
            # only load model files that come from a trusted source.
            if ext == 'pkl':
                with open(smplh_path, 'rb') as smplh_file:
                    model_data = pickle.load(smplh_file, encoding='latin1')
            elif ext == 'npz':
                loaded = np.load(smplh_path, allow_pickle=True)
                try:
                    # Read every entry now so that the file can be released.
                    model_data = dict(loaded)
                finally:
                    if hasattr(loaded, 'close'):
                        loaded.close()
            else:
                raise ValueError('Unknown extension: {}'.format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS['smplh']

        super(SMPLH, self).__init__(model_path=model_path,
                                    kid_template_path=kid_template_path,
                                    data_struct=data_struct,
                                    batch_size=batch_size,
                                    vertex_ids=vertex_ids,
                                    gender=gender,
                                    age=age,
                                    use_compressed=use_compressed,
                                    dtype=dtype,
                                    ext=ext,
                                    **kwargs)

        self.use_pca = use_pca
        self.flat_hand_mean = flat_hand_mean

        # Keep only the first `num_pca_comps` components of each hand.
        left_hand_components = data_struct.hands_componentsl[:num_pca_comps]
        right_hand_components = data_struct.hands_componentsr[:num_pca_comps]

        self.np_left_hand_components = left_hand_components
        self.np_right_hand_components = right_hand_components
        if self.use_pca:
            self.register_buffer(
                'left_hand_components',
                torch.tensor(left_hand_components, dtype=dtype))
            self.register_buffer(
                'right_hand_components',
                torch.tensor(right_hand_components, dtype=dtype))

        if self.flat_hand_mean:
            left_hand_mean = np.zeros_like(data_struct.hands_meanl)
            right_hand_mean = np.zeros_like(data_struct.hands_meanr)
        else:
            left_hand_mean = data_struct.hands_meanl
            right_hand_mean = data_struct.hands_meanr

        self.register_buffer('left_hand_mean',
                             to_tensor(left_hand_mean, dtype=self.dtype))
        self.register_buffer('right_hand_mean',
                             to_tensor(right_hand_mean, dtype=self.dtype))

        # PCA coefficients when use_pca is set, three values per hand joint
        # otherwise.
        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS
        if create_left_hand_pose:
            if left_hand_pose is None:
                default_lhand_pose = torch.zeros([batch_size, hand_pose_dim],
                                                 dtype=dtype)
            elif torch.is_tensor(left_hand_pose):
                # Copy without going through torch.tensor(tensor), which
                # warns; the cast to `dtype` is kept.
                default_lhand_pose = left_hand_pose.clone().detach().to(
                    dtype=dtype)
            else:
                default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype)

            left_hand_pose_param = nn.Parameter(default_lhand_pose,
                                                requires_grad=True)
            self.register_parameter('left_hand_pose', left_hand_pose_param)

        if create_right_hand_pose:
            if right_hand_pose is None:
                default_rhand_pose = torch.zeros([batch_size, hand_pose_dim],
                                                 dtype=dtype)
            elif torch.is_tensor(right_hand_pose):
                default_rhand_pose = right_hand_pose.clone().detach().to(
                    dtype=dtype)
            else:
                default_rhand_pose = torch.tensor(right_hand_pose,
                                                  dtype=dtype)

            right_hand_pose_param = nn.Parameter(default_rhand_pose,
                                                 requires_grad=True)
            self.register_parameter('right_hand_pose', right_hand_pose_param)

        pose_mean_tensor = self.create_mean_pose(
            data_struct, flat_hand_mean=flat_hand_mean)
        if not torch.is_tensor(pose_mean_tensor):
            pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)
        self.register_buffer('pose_mean', pose_mean_tensor)

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        ''' Build the mean of the full pose vector.

            The global orientation and body pose parts are zeros; they are
            followed by the registered left and right hand means. Neither
            argument is read by this implementation.
        '''
        global_orient_mean = torch.zeros([3], dtype=self.dtype)
        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3],
                                     dtype=self.dtype)

        pose_mean = torch.cat([
            global_orient_mean, body_pose_mean, self.left_hand_mean,
            self.right_hand_mean
        ],
                              dim=0)
        return pose_mean

    def name(self) -> str:
        ''' Return the display name of the model. '''
        return 'SMPL+H'

    def extra_repr(self):
        ''' Extend the parent summary with the hand related settings. '''
        msg = super(SMPLH, self).extra_repr()
        msg = [msg]
        if self.use_pca:
            msg.append(f'Number of PCA components: {self.num_pca_comps}')
        msg.append(f'Flat hand mean: {self.flat_hand_mean}')
        return '\n'.join(msg)

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                left_hand_pose: Optional[Tensor] = None,
                right_hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLHOutput:
        ''' Forward pass for the SMPL+H model

            Every argument that is None falls back to the member variable
            of the same name (`transl` only if that member exists).

            Parameters
            ----------
            betas: torch.tensor, optional
                Shape parameters. (default=None)
            global_orient: torch.tensor, optional
                Global rotation of the body. (default=None)
            body_pose: torch.tensor, optional
                Pose of the body joints. (default=None)
            left_hand_pose: torch.tensor, optional
                Pose of the left hand. With `use_pca` it holds PCA
                coefficients that are projected through the left hand
                components. (default=None)
            right_hand_pose: torch.tensor, optional
                Pose of the right hand, handled like `left_hand_pose`.
                (default=None)
            transl: torch.tensor, optional
                Translation added to the joints and vertices.
                (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose vector (default=False)
            pose2rot: bool, optional
                Forwarded unchanged to ``lbs``. (default=True)

            Returns
            -------
            SMPLHOutput
                Holds `joints`, `betas`, `global_orient`, `body_pose` and
                the hand poses as they were concatenated into the full
                pose; `vertices` and `full_pose` are None unless requested.
        '''
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        left_hand_pose = (left_hand_pose if left_hand_pose is not None else
                          self.left_hand_pose)
        right_hand_pose = (right_hand_pose if right_hand_pose is not None else
                           self.right_hand_pose)

        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None:
            if hasattr(self, 'transl'):
                transl = self.transl

        if self.use_pca:
            # Project the PCA coefficients through the hand components.
            left_hand_pose = torch.einsum(
                'bi,ij->bj', [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                'bi,ij->bj', [right_hand_pose, self.right_hand_components])

        full_pose = torch.cat(
            [global_orient, body_pose, left_hand_pose, right_hand_pose],
            dim=1)

        # torch.cat returned a fresh tensor, so the in-place add does not
        # touch any of the inputs.
        full_pose += self.pose_mean

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=pose2rot)

        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=left_hand_pose,
                             right_hand_pose=right_hand_pose,
                             full_pose=full_pose if return_full_pose else None)

        return output
| |
|
| |
|
class SMPLHLayer(SMPLH):
    ''' SMPL+H variant that creates no learnable pose, shape or translation
        parameters; every input is passed to ``forward`` directly.
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' SMPL+H as a layer model constructor

            All `create_*` flags of the parent constructors are switched
            off; the remaining arguments are forwarded unchanged.
        '''
        super(SMPLHLayer, self).__init__(create_global_orient=False,
                                         create_body_pose=False,
                                         create_left_hand_pose=False,
                                         create_right_hand_pose=False,
                                         create_betas=False,
                                         create_transl=False,
                                         *args,
                                         **kwargs)

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                left_hand_pose: Optional[Tensor] = None,
                right_hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> SMPLHOutput:
        ''' Forward pass for the SMPL+H model

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3x3
                Global rotation of the body, in rotation matrix format.
                An identity rotation is used when omitted. (default=None)
            betas: torch.tensor, optional, shape BxN_b
                Shape parameters; zeros are used when omitted.
                (default=None)
            body_pose: torch.tensor, optional, shape BxJx3x3
                Joint rotations of the body in rotation matrix format;
                identity rotations are used when omitted. (default=None)
            left_hand_pose: torch.tensor, optional, shape Bx15x3x3
                Joint rotations of the left hand in rotation matrix format;
                identity rotations are used when omitted. (default=None)
            right_hand_pose: torch.tensor, optional, shape Bx15x3x3
                Joint rotations of the right hand in rotation matrix format;
                identity rotations are used when omitted. (default=None)
            transl: torch.tensor, optional, shape Bx3
                Translation vector of the body; zeros are used when
                omitted. (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose tensor (default=False)
            pose2rot: bool, optional
                Not used by this method: ``lbs`` is always called with
                ``pose2rot=False``.

            Returns
            -------
            SMPLHOutput
                Holds `joints`, `betas`, `global_orient`, `body_pose` and
                both hand poses; `vertices` and `full_pose` are None unless
                requested.
        '''
        model_vars = [
            betas, global_orient, body_pose, transl, left_hand_pose,
            right_hand_pose
        ]
        # The batch size is the longest of the inputs that were actually
        # given, and 1 when nothing was passed.
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        def identity_rotations(num_joints):
            # One 3x3 identity matrix per joint and batch element.
            return torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, num_joints, -1,
                                   -1).contiguous()

        # The default sizes come from the class constants, the same values
        # the reshape calls below rely on.
        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Stack all the rotations along the joint axis.
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
        ],
                              dim=1)

        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=False)

        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if transl is not None:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = SMPLHOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=left_hand_pose,
                             right_hand_pose=right_hand_pose,
                             full_pose=full_pose if return_full_pose else None)

        return output
| |
|
| |
|
| | class SMPLX(SMPLH): |
| | ''' |
| | SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters |
| | trained jointly for the face, hands and body. |
| | SMPL-X uses standard vertex based linear blend skinning with learned |
| | corrective blend shapes, has N=10475 vertices and K=54 joints, |
| | which includes joints for the neck, jaw, eyeballs and fingers. |
| | ''' |
| |
|
| | NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS |
| | NUM_HAND_JOINTS = 15 |
| | NUM_FACE_JOINTS = 3 |
| | NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS |
| | EXPRESSION_SPACE_DIM = 100 |
| | NECK_IDX = 12 |
| |
|
def __init__(self,
             model_path: str,
             kid_template_path: str = '',
             num_expression_coeffs: int = 10,
             create_expression: bool = True,
             expression: Optional[Tensor] = None,
             create_jaw_pose: bool = True,
             jaw_pose: Optional[Tensor] = None,
             create_leye_pose: bool = True,
             leye_pose: Optional[Tensor] = None,
             create_reye_pose=True,
             reye_pose: Optional[Tensor] = None,
             use_face_contour: bool = False,
             batch_size: int = 1,
             gender: str = 'neutral',
             age: str = 'adult',
             dtype=torch.float32,
             ext: str = 'npz',
             **kwargs) -> None:
    ''' SMPLX model constructor

        Parameters
        ----------
        model_path: str
            The path to the folder or to the file where the model
            parameters are stored
        kid_template_path: str, optional
            Forwarded unchanged to the parent constructor.
            (default = '')
        num_expression_coeffs: int, optional
            Number of expression components to use. The value is clamped
            to what the loaded model file provides.
            (default = 10).
        create_expression: bool, optional
            Flag for creating a member variable for the expression space
            (default = True).
        expression: torch.tensor, optional, BxN_e
            The default value for the expression member variable.
            (default = None)
        create_jaw_pose: bool, optional
            Flag for creating a member variable for the jaw pose.
            (default = True)
        jaw_pose: torch.tensor, optional, Bx3
            The default value for the jaw pose variable.
            (default = None)
        create_leye_pose: bool, optional
            Flag for creating a member variable for the left eye pose.
            (default = True)
        leye_pose: torch.tensor, optional, Bx3
            The default value for the left eye pose variable.
            (default = None)
        create_reye_pose: bool, optional
            Flag for creating a member variable for the right eye pose.
            (default = True)
        reye_pose: torch.tensor, optional, Bx3
            The default value for the right eye pose variable.
            (default = None)
        use_face_contour: bool, optional
            Whether to compute the keypoints that form the facial contour
            (default = False)
        batch_size: int, optional
            The batch size used for creating the member variables
        gender: str, optional
            Which gender to load
        age: str, optional
            Forwarded unchanged to the parent constructor.
            (default = 'adult')
        dtype: torch.dtype
            The data type for the created variables
        ext: str, optional
            Extension of the model file, either 'npz' or 'pkl'.
            (default = 'npz')
        **kwargs:
            Forwarded to the parent constructor.

        Raises
        ------
        AssertionError
            If the resolved model file does not exist.
        ValueError
            If `ext` is neither 'pkl' nor 'npz'.
    '''

    # A directory is expanded to the conventional `SMPLX_<GENDER>.<ext>`
    # file name; any other path is used as given.
    if osp.isdir(model_path):
        model_fn = 'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext)
        smplx_path = os.path.join(model_path, model_fn)
    else:
        smplx_path = model_path
    assert osp.exists(smplx_path), 'Path {} does not exist!'.format(
        smplx_path)

    if ext == 'pkl':
        # NOTE(security): unpickling can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(smplx_path, 'rb') as smplx_file:
            model_data = pickle.load(smplx_file, encoding='latin1')
    elif ext == 'npz':
        model_data = np.load(smplx_path, allow_pickle=True)
    else:
        raise ValueError('Unknown extension: {}'.format(ext))

    data_struct = Struct(**model_data)

    super(SMPLX, self).__init__(model_path=model_path,
                                kid_template_path=kid_template_path,
                                data_struct=data_struct,
                                dtype=dtype,
                                batch_size=batch_size,
                                vertex_ids=VERTEX_IDS['smplx'],
                                gender=gender,
                                age=age,
                                ext=ext,
                                **kwargs)

    # Static facial landmarks: face indices and barycentric coordinates.
    lmk_faces_idx = data_struct.lmk_faces_idx
    self.register_buffer('lmk_faces_idx',
                         torch.tensor(lmk_faces_idx, dtype=torch.long))
    lmk_bary_coords = data_struct.lmk_bary_coords
    self.register_buffer('lmk_bary_coords',
                         torch.tensor(lmk_bary_coords, dtype=dtype))

    self.use_face_contour = use_face_contour
    if self.use_face_contour:
        # Buffers that are only read by `forward` when the dynamic face
        # contour is enabled.
        dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx
        dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx,
                                             dtype=torch.long)
        self.register_buffer('dynamic_lmk_faces_idx',
                             dynamic_lmk_faces_idx)

        dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords
        dynamic_lmk_bary_coords = torch.tensor(dynamic_lmk_bary_coords,
                                               dtype=dtype)
        self.register_buffer('dynamic_lmk_bary_coords',
                             dynamic_lmk_bary_coords)

        neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
        self.register_buffer(
            'neck_kin_chain', torch.tensor(neck_kin_chain,
                                           dtype=torch.long))

    # Optional learnable face poses, registered in jaw / left eye / right
    # eye order. Each one defaults to a zero Bx3 tensor.
    for pose_name, create_flag, init_value in (
            ('jaw_pose', create_jaw_pose, jaw_pose),
            ('leye_pose', create_leye_pose, leye_pose),
            ('reye_pose', create_reye_pose, reye_pose)):
        if not create_flag:
            continue
        if init_value is None:
            default_pose = torch.zeros([batch_size, 3], dtype=dtype)
        else:
            default_pose = torch.tensor(init_value, dtype=dtype)
        self.register_parameter(
            pose_name, nn.Parameter(default_pose, requires_grad=True))

    shapedirs = data_struct.shapedirs
    if len(shapedirs.shape) < 3:
        shapedirs = shapedirs[:, :, None]
    if (shapedirs.shape[-1] <
            self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM):
        # Reduced model file: the expression directions start at index 10
        # and at most 10 of them are available.
        expr_start_idx = 10
        max_expression_coeffs = 10
    else:
        expr_start_idx = self.SHAPE_SPACE_DIM
        max_expression_coeffs = self.EXPRESSION_SPACE_DIM

    # Clamp first, then derive the slice end from the clamped value so that
    # `expr_dirs` always has exactly `num_expression_coeffs` columns and
    # matches the width of the `expression` parameter created below.
    num_expression_coeffs = min(num_expression_coeffs,
                                max_expression_coeffs)
    expr_end_idx = expr_start_idx + num_expression_coeffs

    self._num_expression_coeffs = num_expression_coeffs

    expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
    self.register_buffer('expr_dirs',
                         to_tensor(to_np(expr_dirs), dtype=dtype))

    if create_expression:
        if expression is None:
            default_expression = torch.zeros(
                [batch_size, self.num_expression_coeffs], dtype=dtype)
        else:
            default_expression = torch.tensor(expression, dtype=dtype)
        expression_param = nn.Parameter(default_expression,
                                        requires_grad=True)
        self.register_parameter('expression', expression_param)
| |
|
def name(self) -> str:
    ''' Return the human-readable name of this body model. '''
    model_name = 'SMPL-X'
    return model_name
| |
|
@property
def num_expression_coeffs(self):
    ''' Number of expression coefficients stored by the constructor. '''
    coeff_count = self._num_expression_coeffs
    return coeff_count
| |
|
def create_mean_pose(self, data_struct, flat_hand_mean=False):
    ''' Assemble the flat mean vector of the full pose

        The result is laid out as global orientation, body, jaw, left eye,
        right eye, left hand, right hand. Only the two hand segments can
        be non-zero; they are read from `self.left_hand_mean` and
        `self.right_hand_mean`.

        Parameters
        ----------
        data_struct:
            Not used by this implementation.
        flat_hand_mean: bool, optional
            Not used by this implementation.

        Returns
        -------
        pose_mean: np.ndarray
            One-dimensional array produced by `np.concatenate`.
    '''
    # Zero means: 3 values for the global orientation, 3 per body joint
    # and 3 each for the jaw, the left eye and the right eye.
    zero_segments = [
        torch.zeros([num_values], dtype=self.dtype)
        for num_values in (3, self.NUM_BODY_JOINTS * 3, 3, 3, 3)
    ]
    hand_segments = [self.left_hand_mean, self.right_hand_mean]
    return np.concatenate(zero_segments + hand_segments, axis=0)
| |
|
def extra_repr(self):
    ''' Extend the parent representation with the expression count. '''
    parent_repr = super(SMPLX, self).extra_repr()
    expression_line = (
        f'Number of Expression Coefficients: {self.num_expression_coeffs}'
    )
    return '\n'.join((parent_repr, expression_line))
| |
|
def forward(self,
            betas: Optional[Tensor] = None,
            global_orient: Optional[Tensor] = None,
            body_pose: Optional[Tensor] = None,
            left_hand_pose: Optional[Tensor] = None,
            right_hand_pose: Optional[Tensor] = None,
            transl: Optional[Tensor] = None,
            expression: Optional[Tensor] = None,
            jaw_pose: Optional[Tensor] = None,
            leye_pose: Optional[Tensor] = None,
            reye_pose: Optional[Tensor] = None,
            return_verts: bool = True,
            return_full_pose: bool = False,
            pose2rot: bool = True,
            return_joint_transformation: bool = False,
            return_vertex_transformation: bool = False,
            **kwargs) -> SMPLXOutput:
    '''
    Forward pass for the SMPLX model

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the body. Useful if someone wishes to predicts this
            with an external model. (default=None)
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        expression: torch.tensor, optional, shape BxN_e
            If given, ignore the member variable `expression` and use it
            instead. For example, it can used if expression parameters
            `expression` are predicted from some external model.
        body_pose: torch.tensor, optional, shape Bx(J*3)
            If given, ignore the member variable `body_pose` and use it
            instead. It should be a tensor that contains joint rotations
            in axis-angle format. (default=None)
        left_hand_pose: torch.tensor, optional, shape BxP
            If given, ignore the member variable `left_hand_pose` and
            use this instead. It should either contain PCA coefficients or
            joint rotations in axis-angle format.
        right_hand_pose: torch.tensor, optional, shape BxP
            If given, ignore the member variable `right_hand_pose` and
            use this instead. It should either contain PCA coefficients or
            joint rotations in axis-angle format.
        jaw_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `jaw_pose` and
            use this instead. It should contain joint rotations in
            axis-angle format.
        leye_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `leye_pose` and
            use this instead.
        reye_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `reye_pose` and
            use this instead.
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. For example, it can used if the translation
            `transl` is predicted from some external model.
            (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Returns the full axis-angle pose vector (default=False)
        pose2rot: bool, optional
            Forwarded to `lbs`. The dynamic landmark lookup always runs
            with `pose2rot=True`. (default=True)
        return_joint_transformation: bool, optional
            Include the joint transformations returned by `lbs` in the
            output. (default=False)
        return_vertex_transformation: bool, optional
            Include the vertex transformations returned by `lbs` in the
            output. (default=False)

        Returns
        -------
            output: SMPLXOutput
            The joints (with the facial landmarks appended), the
            parameters that were used and, depending on the flags, the
            vertices, the full pose and the transformations.
    '''

    # Every missing argument falls back to the matching member variable.
    global_orient = (global_orient
                     if global_orient is not None else self.global_orient)
    body_pose = body_pose if body_pose is not None else self.body_pose
    betas = betas if betas is not None else self.betas

    left_hand_pose = (left_hand_pose if left_hand_pose is not None else
                      self.left_hand_pose)
    right_hand_pose = (right_hand_pose if right_hand_pose is not None else
                       self.right_hand_pose)
    jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
    leye_pose = leye_pose if leye_pose is not None else self.leye_pose
    reye_pose = reye_pose if reye_pose is not None else self.reye_pose
    expression = expression if expression is not None else self.expression

    apply_trans = transl is not None or hasattr(self, 'transl')
    if transl is None:
        if hasattr(self, 'transl'):
            transl = self.transl

    if self.use_pca:
        # Project the PCA coefficients onto the hand pose components.
        left_hand_pose = torch.einsum(
            'bi,ij->bj', [left_hand_pose, self.left_hand_components])
        right_hand_pose = torch.einsum(
            'bi,ij->bj', [right_hand_pose, self.right_hand_components])

    full_pose = torch.cat([
        global_orient, body_pose, jaw_pose, leye_pose, reye_pose,
        left_hand_pose, right_hand_pose
    ],
                          dim=1)

    # `full_pose` is a fresh tensor created by `torch.cat`, so adding the
    # mean pose in place does not touch any of the inputs.
    full_pose += self.pose_mean

    batch_size = max(betas.shape[0], global_orient.shape[0],
                     body_pose.shape[0])
    # Broadcast a single set of shape coefficients over the batch.
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)
    shape_components = torch.cat([betas, expression], dim=-1)

    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

    if return_joint_transformation or return_vertex_transformation:
        vertices, joints, joint_transformation, vertex_transformation = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
            return_transformation=True)
    else:
        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=pose2rot,
        )

    lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
        batch_size, -1).contiguous()
    # Use the batch size of the current inputs, not the one the module was
    # constructed with, so the landmark tensors match `vertices`.
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
        batch_size, 1, 1)
    if self.use_face_contour:
        lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=True,
        )
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat(
            [lmk_bary_coords, dyn_lmk_bary_coords], 1)

    landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                   lmk_faces_idx, lmk_bary_coords)

    # Extend the skeleton joints with the selected vertices, then append
    # the facial landmarks.
    joints = self.vertex_joint_selector(vertices, joints)
    joints = torch.cat([joints, landmarks], dim=1)

    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if apply_trans:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    output = SMPLXOutput(vertices=vertices if return_verts else None,
                         joints=joints,
                         betas=betas,
                         expression=expression,
                         global_orient=global_orient,
                         body_pose=body_pose,
                         left_hand_pose=left_hand_pose,
                         right_hand_pose=right_hand_pose,
                         jaw_pose=jaw_pose,
                         full_pose=full_pose if return_full_pose else None,
                         joint_transformation=joint_transformation
                         if return_joint_transformation else None,
                         vertex_transformation=vertex_transformation
                         if return_vertex_transformation else None)
    return output
| |
|
| |
|
class SMPLXLayer(SMPLX):
    ''' SMPL-X as a layer: no learnable member variables are created and
    every pose is passed to `forward` as rotation matrices.
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' SMPLX as a layer model constructor

        All `create_*` flags are forced to False; the remaining arguments
        are forwarded unchanged to the `SMPLX` constructor.
        '''
        super(SMPLXLayer, self).__init__(
            create_global_orient=False,
            create_body_pose=False,
            create_left_hand_pose=False,
            create_right_hand_pose=False,
            create_jaw_pose=False,
            create_leye_pose=False,
            create_reye_pose=False,
            create_betas=False,
            create_expression=False,
            create_transl=False,
            *args,
            **kwargs,
        )

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                body_pose: Optional[Tensor] = None,
                left_hand_pose: Optional[Tensor] = None,
                right_hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                expression: Optional[Tensor] = None,
                jaw_pose: Optional[Tensor] = None,
                leye_pose: Optional[Tensor] = None,
                reye_pose: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> SMPLXOutput:
        '''
        Forward pass for the SMPLX model

            Every argument that is left as None is replaced by a neutral
            default: identity rotations for the poses and zeros for the
            shape, expression and translation. The batch size of the
            defaults is the largest leading dimension among the arguments
            that were given (1 if none were).

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3x3
                Global rotation of the body in rotation matrix format.
                It is reshaped to Bx1x3x3. (default=None)
            betas: torch.tensor, optional, shape BxN_b
                Shape coefficients. (default=None)
            expression: torch.tensor, optional, shape BxN_e
                Expression coefficients. (default=None)
            body_pose: torch.tensor, optional, shape BxJx3x3
                Body joint rotations in rotation matrix format.
                (default=None)
            left_hand_pose: torch.tensor, optional, shape Bx15x3x3
                Left hand joint rotations in rotation matrix format.
                (default=None)
            right_hand_pose: torch.tensor, optional, shape Bx15x3x3
                Right hand joint rotations in rotation matrix format.
                (default=None)
            jaw_pose: torch.tensor, optional, shape Bx3x3
                Jaw rotation in rotation matrix format. (default=None)
            leye_pose: torch.tensor, optional, shape Bx3x3
                Left eye rotation in rotation matrix format.
                (default=None)
            reye_pose: torch.tensor, optional, shape Bx3x3
                Right eye rotation in rotation matrix format.
                (default=None)
            transl: torch.tensor, optional, shape Bx3
                Translation vector of the body. (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose tensor (default=False)

            Returns
            -------
                output: SMPLXOutput
                A data class that contains the posed vertices and joints
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every input that was provided. The eye
        # poses are included so that passing only them still yields
        # defaults with a matching batch dimension.
        model_vars = [
            betas, global_orient, body_pose, transl, expression,
            left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose
        ]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))

        def identity_rotations(num_joints: int) -> Tensor:
            # Batch of identity matrices, shape batch_size x num_joints x 3 x 3.
            return torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, num_joints, -1,
                                   -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(15)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(15)
        if jaw_pose is None:
            jaw_pose = identity_rotations(1)
        if leye_pose is None:
            leye_pose = identity_rotations(1)
        if reye_pose is None:
            reye_pose = identity_rotations(1)
        if expression is None:
            expression = torch.zeros([batch_size, self.num_expression_coeffs],
                                     dtype=dtype,
                                     device=device)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Concatenate all pose segments along the joint dimension.
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            jaw_pose.reshape(-1, 1, 3, 3),
            leye_pose.reshape(-1, 1, 3, 3),
            reye_pose.reshape(-1, 1, 3, 3),
            left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
        ],
                              dim=1)
        shape_components = torch.cat([betas, expression], dim=-1)

        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
            batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords

            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat([
                lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords
            ], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        # Extend the skeleton joints with the selected vertices, then
        # append the facial landmarks.
        joints = self.vertex_joint_selector(vertices, joints)
        joints = torch.cat([joints, landmarks], dim=1)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        # `transl` is never None here: it was defaulted to zeros above.
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = SMPLXOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             expression=expression,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=left_hand_pose,
                             right_hand_pose=right_hand_pose,
                             jaw_pose=jaw_pose,
                             transl=transl,
                             full_pose=full_pose if return_full_pose else None)
        return output
| |
|
| |
|
class MANO(SMPL):
    ''' MANO hand model built on top of the `SMPL` base class. '''

    NUM_BODY_JOINTS = 1
    NUM_HAND_JOINTS = 15
    NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS

    def __init__(self,
                 model_path: str,
                 is_rhand: bool = True,
                 data_struct: Optional[Struct] = None,
                 create_hand_pose: bool = True,
                 hand_pose: Optional[Tensor] = None,
                 use_pca: bool = True,
                 num_pca_comps: int = 6,
                 flat_hand_mean: bool = False,
                 batch_size: int = 1,
                 dtype=torch.float32,
                 vertex_ids=None,
                 use_compressed: bool = True,
                 ext: str = 'pkl',
                 **kwargs) -> None:
        ''' MANO model constructor

            Parameters
            ----------
            model_path: str
                The path to the folder or to the file where the model
                parameters are stored
            is_rhand: bool, optional
                Selects the `MANO_RIGHT` / `MANO_LEFT` file when
                `model_path` is a folder. When `model_path` is a file, the
                handedness is instead read from the file name.
                (default = True)
            data_struct: Struct
                A struct object. If given, then the parameters of the
                model are read from the object. Otherwise, the model tries
                to read the parameters from the given `model_path`.
                (default = None)
            create_hand_pose: bool, optional
                Flag for creating a member variable for the pose of the
                hand. (default = True)
            hand_pose: torch.tensor, optional, BxP
                The default value for the hand pose member variable.
                (default = None)
            use_pca: bool, optional
                Whether the hand pose is given as PCA coefficients. It is
                switched off when `num_pca_comps` is 45. (default = True)
            num_pca_comps: int, optional
                The number of PCA components to use for the hand.
                (default = 6)
            flat_hand_mean: bool, optional
                If True, the mean hand pose is all zeros; otherwise the
                mean stored in the model data is used. (default = False)
            batch_size: int, optional
                The batch size used for creating the member variables
            dtype: torch.dtype, optional
                The data type for the created variables
            vertex_ids: dict, optional
                A dictionary containing the indices of the extra vertices
                that will be selected
            use_compressed: bool, optional
                Forwarded to the parent constructor. (default = True)
            ext: str, optional
                Extension of the model file, 'pkl' or 'npz'.
                (default = 'pkl')
        '''

        self.num_pca_comps = num_pca_comps
        self.is_rhand = is_rhand

        # Without a ready-made struct the parameters are read from disk.
        if data_struct is None:
            if osp.isdir(model_path):
                side = 'RIGHT' if is_rhand else 'LEFT'
                model_fn = 'MANO_{}.{ext}'.format(side, ext=ext)
                mano_path = os.path.join(model_path, model_fn)
            else:
                mano_path = model_path
                # An explicit file decides the handedness through its name.
                self.is_rhand = 'RIGHT' in os.path.basename(model_path)
            assert osp.exists(mano_path), 'Path {} does not exist!'.format(
                mano_path)

            if ext == 'npz':
                model_data = np.load(mano_path, allow_pickle=True)
            elif ext == 'pkl':
                with open(mano_path, 'rb') as mano_file:
                    model_data = pickle.load(mano_file, encoding='latin1')
            else:
                raise ValueError('Unknown extension: {}'.format(ext))
            data_struct = Struct(**model_data)

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS['smplh']

        super().__init__(model_path=model_path,
                         data_struct=data_struct,
                         batch_size=batch_size,
                         vertex_ids=vertex_ids,
                         use_compressed=use_compressed,
                         dtype=dtype,
                         ext=ext,
                         **kwargs)

        # Replace the extra joints of the selector with the MANO vertex ids.
        mano_vertex_ids = list(VERTEX_IDS['mano'].values())
        self.vertex_joint_selector.extra_joints_idxs = to_tensor(
            mano_vertex_ids, dtype=torch.long)

        self.use_pca = use_pca
        self.num_pca_comps = num_pca_comps
        if self.num_pca_comps == 45:
            self.use_pca = False
        self.flat_hand_mean = flat_hand_mean

        pca_basis = data_struct.hands_components[:num_pca_comps]
        self.np_hand_components = pca_basis
        if self.use_pca:
            self.register_buffer('hand_components',
                                 torch.tensor(pca_basis, dtype=dtype))

        mean_source = data_struct.hands_mean
        hand_mean = (np.zeros_like(mean_source)
                     if self.flat_hand_mean else mean_source)
        self.register_buffer('hand_mean', to_tensor(hand_mean,
                                                    dtype=self.dtype))

        # Width of the hand pose parameter: PCA coefficients or axis-angle.
        if use_pca:
            hand_pose_dim = num_pca_comps
        else:
            hand_pose_dim = 3 * self.NUM_HAND_JOINTS
        if create_hand_pose:
            if hand_pose is not None:
                initial_hand_pose = torch.tensor(hand_pose, dtype=dtype)
            else:
                initial_hand_pose = torch.zeros([batch_size, hand_pose_dim],
                                                dtype=dtype)
            self.register_parameter(
                'hand_pose',
                nn.Parameter(initial_hand_pose, requires_grad=True))

        # Buffer holding the mean of the full pose vector.
        mean_pose = self.create_mean_pose(data_struct,
                                          flat_hand_mean=flat_hand_mean)
        self.register_buffer('pose_mean', mean_pose.clone().to(dtype))

    def name(self) -> str:
        ''' Return the human-readable name of this model. '''
        return 'MANO'

    def create_mean_pose(self, data_struct, flat_hand_mean=False):
        ''' Build the mean of the full pose vector

            A zero global orientation (3 values) is placed in front of the
            `hand_mean` buffer. Both arguments are unused here.
        '''
        zero_orient = torch.zeros([3], dtype=self.dtype)
        return torch.cat([zero_orient, self.hand_mean], dim=0)

    def extra_repr(self):
        ''' Extend the parent representation with the hand settings. '''
        lines = [super().extra_repr()]
        if self.use_pca:
            lines.append(f'Number of PCA components: {self.num_pca_comps}')
        lines.append(f'Flat hand mean: {self.flat_hand_mean}')
        return '\n'.join(lines)

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> MANOOutput:
        ''' Forward pass for the MANO model

            Arguments left as None fall back to the member variables of
            the same name. The translation is applied only when it is
            passed in or exists as a member variable.

            Returns
            -------
                output: MANOOutput
                `vertices` and `joints` are both None when `return_verts`
                is False; `full_pose` is None unless `return_full_pose`.
        '''
        if global_orient is None:
            global_orient = self.global_orient
        if betas is None:
            betas = self.betas
        if hand_pose is None:
            hand_pose = self.hand_pose

        has_member_transl = hasattr(self, 'transl')
        apply_trans = transl is not None or has_member_transl
        if transl is None and has_member_transl:
            transl = self.transl

        if self.use_pca:
            # Project the PCA coefficients onto the hand pose components.
            hand_pose = torch.einsum('bi,ij->bj',
                                     [hand_pose, self.hand_components])

        full_pose = torch.cat([global_orient, hand_pose], dim=1)
        full_pose += self.pose_mean

        vertices, joints = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=True,
        )

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            offset = transl.unsqueeze(dim=1)
            joints = joints + offset
            vertices = vertices + offset

        # NOTE(review): the joints are gated by `return_verts` as well;
        # kept as is because callers may rely on it.
        return MANOOutput(vertices=vertices if return_verts else None,
                          joints=joints if return_verts else None,
                          betas=betas,
                          global_orient=global_orient,
                          hand_pose=hand_pose,
                          full_pose=full_pose if return_full_pose else None)
| |
|
| |
|
class MANOLayer(MANO):
    ''' MANO as a layer: no learnable member variables are created and the
    poses are passed to `forward` as rotation matrices.
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' MANO as a layer model constructor

        All `create_*` flags are forced to False; the remaining arguments
        are forwarded unchanged to the `MANO` constructor.
        '''
        super(MANOLayer, self).__init__(create_global_orient=False,
                                        create_hand_pose=False,
                                        create_betas=False,
                                        create_transl=False,
                                        *args,
                                        **kwargs)

    def name(self) -> str:
        ''' Return the human-readable name of this model. '''
        return 'MANO'

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                hand_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> MANOOutput:
        ''' Forward pass for the MANO model

            Every argument that is left as None is replaced by a neutral
            default: identity rotations for the poses and zeros for the
            shape and translation. The batch size of the defaults is the
            largest leading dimension among the arguments that were given
            (1 if none were).

            Parameters
            ----------
            betas: torch.tensor, optional, shape BxN_b
                Shape coefficients. (default=None)
            global_orient: torch.tensor, optional, shape Bx3x3
                Global rotation in rotation matrix format. It is reshaped
                to Bx1x3x3 before use. (default=None)
            hand_pose: torch.tensor, optional, shape Bx15x3x3
                Hand joint rotations in rotation matrix format.
                (default=None)
            transl: torch.tensor, optional, shape Bx3
                Translation vector. (default=None)
            return_verts: bool, optional
                If False, both `vertices` and `joints` of the output are
                None. (default=True)
            return_full_pose: bool, optional
                Returns the full pose tensor (default=False)

            Returns
            -------
                output: MANOOutput
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Infer the batch size from every input that was provided, so the
        # defaults match it even when `global_orient` is omitted.
        batch_size = 1
        for var in (betas, global_orient, hand_pose, transl):
            if var is None:
                continue
            batch_size = max(batch_size, len(var))

        if global_orient is None:
            global_orient = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
        if hand_pose is None:
            hand_pose = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        # Reshape before concatenating so that a Bx3x3 global orientation
        # is accepted as well as a Bx1x3x3 one.
        full_pose = torch.cat([
            global_orient.reshape(-1, 1, 3, 3),
            hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
        ],
                              dim=1)
        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=False)

        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        # `transl` is never None here: it was defaulted to zeros above.
        joints = joints + transl.unsqueeze(dim=1)
        vertices = vertices + transl.unsqueeze(dim=1)

        # NOTE(review): the joints are gated by `return_verts` as well;
        # kept as is because callers may rely on it.
        output = MANOOutput(vertices=vertices if return_verts else None,
                            joints=joints if return_verts else None,
                            betas=betas,
                            global_orient=global_orient,
                            hand_pose=hand_pose,
                            full_pose=full_pose if return_full_pose else None)

        return output
| |
|
| |
|
| | class FLAME(SMPL): |
| | NUM_JOINTS = 5 |
| | SHAPE_SPACE_DIM = 300 |
| | EXPRESSION_SPACE_DIM = 100 |
| | NECK_IDX = 0 |
| |
|
| | def __init__(self, |
| | model_path: str, |
| | data_struct=None, |
| | num_expression_coeffs=10, |
| | create_expression: bool = True, |
| | expression: Optional[Tensor] = None, |
| | create_neck_pose: bool = True, |
| | neck_pose: Optional[Tensor] = None, |
| | create_jaw_pose: bool = True, |
| | jaw_pose: Optional[Tensor] = None, |
| | create_leye_pose: bool = True, |
| | leye_pose: Optional[Tensor] = None, |
| | create_reye_pose=True, |
| | reye_pose: Optional[Tensor] = None, |
| | use_face_contour=False, |
| | batch_size: int = 1, |
| | gender: str = 'neutral', |
| | dtype: torch.dtype = torch.float32, |
| | ext='pkl', |
| | **kwargs) -> None: |
| | ''' FLAME model constructor |
| | |
| | Parameters |
| | ---------- |
| | model_path: str |
| | The path to the folder or to the file where the model |
| | parameters are stored |
| | num_expression_coeffs: int, optional |
| | Number of expression components to use |
| | (default = 10). |
| | create_expression: bool, optional |
| | Flag for creating a member variable for the expression space |
| | (default = True). |
| | expression: torch.tensor, optional, Bx10 |
| | The default value for the expression member variable. |
| | (default = None) |
| | create_neck_pose: bool, optional |
| | Flag for creating a member variable for the neck pose. |
| | (default = False) |
| | neck_pose: torch.tensor, optional, Bx3 |
| | The default value for the neck pose variable. |
| | (default = None) |
| | create_jaw_pose: bool, optional |
| | Flag for creating a member variable for the jaw pose. |
| | (default = False) |
| | jaw_pose: torch.tensor, optional, Bx3 |
| | The default value for the jaw pose variable. |
| | (default = None) |
| | create_leye_pose: bool, optional |
| | Flag for creating a member variable for the left eye pose. |
| | (default = False) |
| | leye_pose: torch.tensor, optional, Bx10 |
| | The default value for the left eye pose variable. |
| | (default = None) |
| | create_reye_pose: bool, optional |
| | Flag for creating a member variable for the right eye pose. |
| | (default = False) |
| | reye_pose: torch.tensor, optional, Bx10 |
| | The default value for the right eye pose variable. |
| | (default = None) |
| | use_face_contour: bool, optional |
| | Whether to compute the keypoints that form the facial contour |
| | batch_size: int, optional |
| | The batch size used for creating the member variables |
| | gender: str, optional |
| | Which gender to load |
| | dtype: torch.dtype |
| | The data type for the created variables |
| | ''' |
| | model_fn = f'FLAME_{gender.upper()}.{ext}' |
| | flame_path = os.path.join(model_path, model_fn) |
| | assert osp.exists(flame_path), 'Path {} does not exist!'.format( |
| | flame_path) |
| | if ext == 'npz': |
| | file_data = np.load(flame_path, allow_pickle=True) |
| | elif ext == 'pkl': |
| | with open(flame_path, 'rb') as smpl_file: |
| | file_data = pickle.load(smpl_file, encoding='latin1') |
| | else: |
| | raise ValueError('Unknown extension: {}'.format(ext)) |
| | data_struct = Struct(**file_data) |
| |
|
| | super(FLAME, self).__init__(model_path=model_path, |
| | data_struct=data_struct, |
| | dtype=dtype, |
| | batch_size=batch_size, |
| | gender=gender, |
| | ext=ext, |
| | **kwargs) |
| |
|
| | self.use_face_contour = use_face_contour |
| |
|
| | self.vertex_joint_selector.extra_joints_idxs = to_tensor( |
| | [], dtype=torch.long) |
| |
|
| | if create_neck_pose: |
| | if neck_pose is None: |
| | default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| | else: |
| | default_neck_pose = torch.tensor(neck_pose, dtype=dtype) |
| | neck_pose_param = nn.Parameter(default_neck_pose, |
| | requires_grad=True) |
| | self.register_parameter('neck_pose', neck_pose_param) |
| |
|
| | if create_jaw_pose: |
| | if jaw_pose is None: |
| | default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| | else: |
| | default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) |
| | jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) |
| | self.register_parameter('jaw_pose', jaw_pose_param) |
| |
|
| | if create_leye_pose: |
| | if leye_pose is None: |
| | default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| | else: |
| | default_leye_pose = torch.tensor(leye_pose, dtype=dtype) |
| | leye_pose_param = nn.Parameter(default_leye_pose, |
| | requires_grad=True) |
| | self.register_parameter('leye_pose', leye_pose_param) |
| |
|
| | if create_reye_pose: |
| | if reye_pose is None: |
| | default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) |
| | else: |
| | default_reye_pose = torch.tensor(reye_pose, dtype=dtype) |
| | reye_pose_param = nn.Parameter(default_reye_pose, |
| | requires_grad=True) |
| | self.register_parameter('reye_pose', reye_pose_param) |
| |
|
| | shapedirs = data_struct.shapedirs |
| | if len(shapedirs.shape) < 3: |
| | shapedirs = shapedirs[:, :, None] |
| | if (shapedirs.shape[-1] < |
| | self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM): |
| | |
| | |
| | expr_start_idx = 10 |
| | expr_end_idx = 20 |
| | num_expression_coeffs = min(num_expression_coeffs, 10) |
| | else: |
| | expr_start_idx = self.SHAPE_SPACE_DIM |
| | expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs |
| | num_expression_coeffs = min(num_expression_coeffs, |
| | self.EXPRESSION_SPACE_DIM) |
| |
|
| | self._num_expression_coeffs = num_expression_coeffs |
| |
|
| | expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] |
| | self.register_buffer('expr_dirs', |
| | to_tensor(to_np(expr_dirs), dtype=dtype)) |
| |
|
| | if create_expression: |
| | if expression is None: |
| | default_expression = torch.zeros( |
| | [batch_size, self.num_expression_coeffs], dtype=dtype) |
| | else: |
| | default_expression = torch.tensor(expression, dtype=dtype) |
| | expression_param = nn.Parameter(default_expression, |
| | requires_grad=True) |
| | self.register_parameter('expression', expression_param) |
| |
|
| | |
| | |
| | landmark_bcoord_filename = osp.join(model_path, |
| | 'flame_static_embedding.pkl') |
| |
|
| | with open(landmark_bcoord_filename, 'rb') as fp: |
| | landmarks_data = pickle.load(fp, encoding='latin1') |
| |
|
| | lmk_faces_idx = landmarks_data['lmk_face_idx'].astype(np.int64) |
| | self.register_buffer('lmk_faces_idx', |
| | torch.tensor(lmk_faces_idx, dtype=torch.long)) |
| | lmk_bary_coords = landmarks_data['lmk_b_coords'] |
| | self.register_buffer('lmk_bary_coords', |
| | torch.tensor(lmk_bary_coords, dtype=dtype)) |
| | if self.use_face_contour: |
| | face_contour_path = os.path.join(model_path, |
| | 'flame_dynamic_embedding.npy') |
| | contour_embeddings = np.load(face_contour_path, |
| | allow_pickle=True, |
| | encoding='latin1')[()] |
| |
|
| | dynamic_lmk_faces_idx = np.array( |
| | contour_embeddings['lmk_face_idx'], dtype=np.int64) |
| | dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, |
| | dtype=torch.long) |
| | self.register_buffer('dynamic_lmk_faces_idx', |
| | dynamic_lmk_faces_idx) |
| |
|
| | dynamic_lmk_b_coords = torch.tensor( |
| | contour_embeddings['lmk_b_coords'], dtype=dtype) |
| | self.register_buffer('dynamic_lmk_bary_coords', |
| | dynamic_lmk_b_coords) |
| |
|
| | neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) |
| | self.register_buffer( |
| | 'neck_kin_chain', torch.tensor(neck_kin_chain, |
| | dtype=torch.long)) |
| |
|
@property
def num_expression_coeffs(self):
    ''' Number of expression coefficients in use.

        Read-only view of the private ``_num_expression_coeffs`` attribute
        that the constructor stores on the instance.
    '''
    coeff_count = self._num_expression_coeffs
    return coeff_count
| |
|
def name(self) -> str:
    ''' Return the identifier string of this model type. '''
    model_name = 'FLAME'
    return model_name
| |
|
def extra_repr(self):
    ''' Extend the parent class' representation with FLAME settings.

        Returns
        -------
        str
            The parent ``extra_repr`` text followed by one line with the
            number of expression coefficients and one line with the face
            contour flag, joined by newlines.
    '''
    parent_repr = super(FLAME, self).extra_repr()
    expression_line = (
        f'Number of Expression Coefficients: {self.num_expression_coeffs}')
    contour_line = f'Use face contour: {self.use_face_contour}'
    return '\n'.join((parent_repr, expression_line, contour_line))
| |
|
def forward(self,
            betas: Optional[Tensor] = None,
            global_orient: Optional[Tensor] = None,
            neck_pose: Optional[Tensor] = None,
            transl: Optional[Tensor] = None,
            expression: Optional[Tensor] = None,
            jaw_pose: Optional[Tensor] = None,
            leye_pose: Optional[Tensor] = None,
            reye_pose: Optional[Tensor] = None,
            return_verts: bool = True,
            return_full_pose: bool = False,
            pose2rot: bool = True,
            **kwargs) -> FLAMEOutput:
    ''' Forward pass for the FLAME model

        Every tensor argument that is left as None falls back to the
        member variable of the same name.

        Parameters
        ----------
        betas: torch.tensor, optional, shape BxN_b
            If given, ignore the member variable `betas` and use it
            instead. For example, it can used if shape parameters
            `betas` are predicted from some external model.
            (default=None)
        global_orient: torch.tensor, optional, shape Bx3
            If given, ignore the member variable and use it as the global
            rotation of the head. Useful if someone wishes to predicts
            this with an external model. (default=None)
        neck_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `neck_pose` and
            use this instead. (default=None)
        transl: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `transl` and use it
            instead. When neither the argument nor the member variable
            exists, no translation is applied. (default=None)
        expression: torch.tensor, optional, shape BxN_e
            If given, ignore the member variable `expression` and use it
            instead. For example, it can used if expression parameters
            `expression` are predicted from some external model.
            (default=None)
        jaw_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `jaw_pose` and
            use this instead. (default=None)
        leye_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `leye_pose` and
            use this instead. (default=None)
        reye_pose: torch.tensor, optional, shape Bx3
            If given, ignore the member variable `reye_pose` and
            use this instead. (default=None)
        return_verts: bool, optional
            Return the vertices. (default=True)
        return_full_pose: bool, optional
            Return the concatenated pose tensor (default=False)
        pose2rot: bool, optional
            Forwarded to `lbs` and to the dynamic landmark search.
            NOTE(review): presumably True means the poses are axis-angle
            vectors and False means rotation matrices - confirm against
            `lbs`. (default=True)

        Returns
        -------
            output: FLAMEOutput
                vertices (None unless `return_verts`), joints (the
                selected joints followed by the facial landmarks), betas,
                expression, global_orient, neck_pose, jaw_pose and
                full_pose (None unless `return_full_pose`)
    '''

    # Fall back to the member variables for everything that was not
    # passed explicitly.
    if global_orient is None:
        global_orient = self.global_orient
    if jaw_pose is None:
        jaw_pose = self.jaw_pose
    if neck_pose is None:
        neck_pose = self.neck_pose

    if leye_pose is None:
        leye_pose = self.leye_pose
    if reye_pose is None:
        reye_pose = self.reye_pose

    if betas is None:
        betas = self.betas
    if expression is None:
        expression = self.expression

    # A translation is applied when it is passed in or when the module
    # owns a `transl` attribute.
    apply_trans = transl is not None or hasattr(self, 'transl')
    if transl is None and hasattr(self, 'transl'):
        transl = self.transl

    full_pose = torch.cat(
        [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)

    batch_size = max(betas.shape[0], global_orient.shape[0],
                     jaw_pose.shape[0])
    # Broadcast a single set of shape coefficients over the whole batch.
    scale = int(batch_size / betas.shape[0])
    if scale > 1:
        betas = betas.expand(scale, -1)
    shape_components = torch.cat([betas, expression], dim=-1)
    shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

    vertices, joints = lbs(
        shape_components,
        full_pose,
        self.v_template,
        shapedirs,
        self.posedirs,
        self.J_regressor,
        self.parents,
        self.lbs_weights,
        pose2rot=pose2rot,
    )

    lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
        batch_size, -1).contiguous()
    # Tile with the batch size of this call rather than the batch size
    # the module was constructed with; the two may differ and the face
    # indices above already follow the runtime value.
    lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
        batch_size, 1, 1)
    if self.use_face_contour:
        # The pose representation must match the one handed to `lbs`,
        # so forward the caller's `pose2rot` instead of assuming True.
        lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain,
            pose2rot=pose2rot,
        )
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat(
            [lmk_bary_coords, dyn_lmk_bary_coords], 1)

    landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                   lmk_faces_idx, lmk_bary_coords)

    # Add any extra joints that might be needed
    joints = self.vertex_joint_selector(vertices, joints)
    # Add the landmarks to the joints
    joints = torch.cat([joints, landmarks], dim=1)

    # Map the joints to the current dataset
    if self.joint_mapper is not None:
        joints = self.joint_mapper(joints=joints, vertices=vertices)

    if apply_trans:
        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

    output = FLAMEOutput(vertices=vertices if return_verts else None,
                         joints=joints,
                         betas=betas,
                         expression=expression,
                         global_orient=global_orient,
                         neck_pose=neck_pose,
                         jaw_pose=jaw_pose,
                         full_pose=full_pose if return_full_pose else None)
    return output
| |
|
| |
|
class FLAMELayer(FLAME):
    ''' FLAME variant that owns no pose/shape member variables.

        All creation flags are switched off in the constructor, so every
        quantity is either passed to `forward` or replaced there by a
        neutral default (identity rotations, zero coefficients).
    '''

    def __init__(self, *args, **kwargs) -> None:
        ''' FLAME as a layer model constructor

            Positional and keyword arguments are forwarded to the FLAME
            constructor with all `create_*` flags forced to False.
        '''
        super(FLAMELayer, self).__init__(create_betas=False,
                                         create_expression=False,
                                         create_global_orient=False,
                                         create_neck_pose=False,
                                         create_jaw_pose=False,
                                         create_leye_pose=False,
                                         create_reye_pose=False,
                                         *args,
                                         **kwargs)

    @staticmethod
    def _identity_rotation(batch_size, device, dtype):
        ''' Build a contiguous (batch_size, 1, 3, 3) identity rotation. '''
        return torch.eye(3, device=device, dtype=dtype).view(
            1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()

    def forward(self,
                betas: Optional[Tensor] = None,
                global_orient: Optional[Tensor] = None,
                neck_pose: Optional[Tensor] = None,
                transl: Optional[Tensor] = None,
                expression: Optional[Tensor] = None,
                jaw_pose: Optional[Tensor] = None,
                leye_pose: Optional[Tensor] = None,
                reye_pose: Optional[Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                pose2rot: bool = True,
                **kwargs) -> FLAMEOutput:
        ''' Forward pass for the FLAME layer

            Rotations are handed to `lbs` with `pose2rot=False`.
            Missing rotations default to the identity, missing
            coefficients and translation default to zeros.

            Parameters
            ----------
            betas: torch.tensor, optional, shape BxN_b
                Shape parameters. For example, it can used if shape
                parameters `betas` are predicted from some external
                model. (default=None)
            global_orient: torch.tensor, optional, shape Bx1x3x3
                Global rotation of the head. It is expected to be in
                rotation matrix format. When given, its first dimension
                defines the batch size. (default=None)
            neck_pose: torch.tensor, optional, shape Bx1x3x3
                Neck rotation in rotation matrix format. (default=None)
            transl: torch.tensor, optional, shape Bx3
                Translation vector. For example, it can used if the
                translation `transl` is predicted from some external
                model. (default=None)
            expression: torch.tensor, optional, shape BxN_e
                Expression parameters. (default=None)
            jaw_pose: torch.tensor, optional, shape Bx1x3x3
                Jaw rotation in rotation matrix format. (default=None)
            leye_pose: torch.tensor, optional, shape Bx1x3x3
                Left eye rotation in rotation matrix format.
                (default=None)
            reye_pose: torch.tensor, optional, shape Bx1x3x3
                Right eye rotation in rotation matrix format.
                (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Return the concatenated pose tensor (default=False)
            pose2rot: bool, optional
                Accepted so the signature matches FLAME.forward; this
                method does not read it and always passes
                `pose2rot=False` downstream.

            Returns
            -------
                output: FLAMEOutput
                    vertices (None unless `return_verts`), joints (the
                    selected joints followed by the facial landmarks),
                    betas, expression, global_orient, neck_pose, jaw_pose
                    and full_pose (None unless `return_full_pose`)
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        if global_orient is not None:
            batch_size = global_orient.shape[0]
        else:
            # Without a global rotation, take the batch size from the
            # first input that was actually provided so the defaults
            # created below line up with it; fall back to 1 otherwise.
            provided = [
                tensor for tensor in (betas, neck_pose, transl, expression,
                                      jaw_pose, leye_pose, reye_pose)
                if tensor is not None
            ]
            batch_size = provided[0].shape[0] if provided else 1
            global_orient = self._identity_rotation(
                batch_size, device, dtype)

        if neck_pose is None:
            neck_pose = self._identity_rotation(batch_size, device, dtype)
        if jaw_pose is None:
            jaw_pose = self._identity_rotation(batch_size, device, dtype)
        if leye_pose is None:
            leye_pose = self._identity_rotation(batch_size, device, dtype)
        if reye_pose is None:
            reye_pose = self._identity_rotation(batch_size, device, dtype)
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas],
                                dtype=dtype,
                                device=device)
        if expression is None:
            expression = torch.zeros(
                [batch_size, self.num_expression_coeffs],
                dtype=dtype,
                device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat(
            [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose],
            dim=1)

        shape_components = torch.cat([betas, expression], dim=-1)
        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            self.v_template,
            shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            pose2rot=False,
        )

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1).contiguous()
        # Tile with the batch size of this call rather than the batch
        # size the module was constructed with; the two may differ and
        # the face indices above already follow the runtime value.
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
            batch_size, 1, 1)
        if self.use_face_contour:
            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                pose2rot=False,
            )
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat(
                [lmk_bary_coords, dyn_lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)

        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        joints += transl.unsqueeze(dim=1)
        vertices += transl.unsqueeze(dim=1)

        output = FLAMEOutput(
            vertices=vertices if return_verts else None,
            joints=joints,
            betas=betas,
            expression=expression,
            global_orient=global_orient,
            neck_pose=neck_pose,
            jaw_pose=jaw_pose,
            full_pose=full_pose if return_full_pose else None)
        return output
| |
|
| |
|
def build_layer(
    model_path: str,
    model_type: str = 'smpl',
    **kwargs
) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]:
    ''' Instantiate the layer variant of a body model

        Parameters
        ----------
        model_path: str
            Either a model file or a root folder. For a root folder the
            sub-folder named after `model_type` is used, e.g.:
            model_path:
            |
            |-- smpl
                |-- SMPL_FEMALE
                |-- SMPL_NEUTRAL
                |-- SMPL_MALE
            |-- smplh
                |-- SMPLH_FEMALE
                |-- SMPLH_MALE
            |-- smplx
                |-- SMPLX_FEMALE
                |-- SMPLX_NEUTRAL
                |-- SMPLX_MALE
            |-- mano
                |-- MANO RIGHT
                |-- MANO LEFT
            |-- flame
                |-- FLAME_FEMALE
                |-- FLAME_MALE
                |-- FLAME_NEUTRAL

        model_type: str, optional
            Model family to load when `model_path` is a folder. When
            `model_path` is not a folder, this value is replaced by the
            lower-cased part of the file name before the first
            underscore.
        **kwargs: dict
            Keyword arguments passed on to the layer constructor

        Returns
        -------
            body_model: nn.Module
                The PyTorch module that implements the corresponding
                body model layer
        Raises
        ------
            ValueError: In case the model type is not one of SMPL, SMPLH,
            SMPLX, MANO or FLAME
    '''
    if not osp.isdir(model_path):
        # A direct path: derive the family from the file name prefix.
        model_type = osp.basename(model_path).split('_')[0].lower()
    else:
        model_path = os.path.join(model_path, model_type)

    kind = model_type.lower()
    if kind == 'smpl':
        return SMPLLayer(model_path, **kwargs)
    if kind == 'smplh':
        return SMPLHLayer(model_path, **kwargs)
    if kind == 'smplx':
        return SMPLXLayer(model_path, **kwargs)
    # MANO and FLAME are matched by substring rather than equality.
    if 'mano' in kind:
        return MANOLayer(model_path, **kwargs)
    if 'flame' in kind:
        return FLAMELayer(model_path, **kwargs)
    raise ValueError(f'Unknown model type {model_type}, exiting!')
| |
|
| |
|
def create(model_path: str,
           model_type: str = 'smpl',
           **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
    ''' Instantiate a body model from a path and a model type

        Parameters
        ----------
        model_path: str
            Either a model file or a root folder. For a root folder the
            sub-folder named after `model_type` is used, e.g.:
            model_path:
            |
            |-- smpl
                |-- SMPL_FEMALE
                |-- SMPL_NEUTRAL
                |-- SMPL_MALE
            |-- smplh
                |-- SMPLH_FEMALE
                |-- SMPLH_MALE
            |-- smplx
                |-- SMPLX_FEMALE
                |-- SMPLX_NEUTRAL
                |-- SMPLX_MALE
            |-- mano
                |-- MANO RIGHT
                |-- MANO LEFT

        model_type: str, optional
            Model family to load when `model_path` is a folder. When
            `model_path` is not a folder, this value is replaced by the
            lower-cased part of the file name before the first
            underscore.
        **kwargs: dict
            Keyword arguments passed on to the model constructor

        Returns
        -------
            body_model: nn.Module
                The PyTorch module that implements the corresponding
                body model
        Raises
        ------
            ValueError: In case the model type is not one of SMPL, SMPLH,
            SMPLX, MANO or FLAME
    '''
    if not osp.isdir(model_path):
        # A direct path: derive the family from the file name prefix.
        model_type = osp.basename(model_path).split('_')[0].lower()
    else:
        # A root folder: descend into the sub-folder of the family.
        model_path = os.path.join(model_path, model_type)

    kind = model_type.lower()
    if kind == 'smpl':
        return SMPL(model_path, **kwargs)
    if kind == 'smplh':
        return SMPLH(model_path, **kwargs)
    if kind == 'smplx':
        return SMPLX(model_path, **kwargs)
    # MANO and FLAME are matched by substring rather than equality.
    if 'mano' in kind:
        return MANO(model_path, **kwargs)
    if 'flame' in kind:
        return FLAME(model_path, **kwargs)
    raise ValueError(f'Unknown model type {model_type}, exiting!')
| |
|