| | """ |
| | original from https://github.com/vchoutas/smplx |
| | modified by Vassilis and Yao |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import pickle |
| |
|
| | from .lbs import ( |
| | Struct, |
| | to_tensor, |
| | to_np, |
| | lbs, |
| | vertices2landmarks, |
| | JointsFromVerticesSelector, |
| | find_dynamic_lmk_idx_and_bcoords, |
| | ) |
| |
|
| | |
# The 14 joint names handled by the J14 regressor set up in SMPLX.__init__.
# The position of a name in this list is used as its target index
# (presumably the regressor's output row order -- confirm against the data).
J14_NAMES = [
    # right leg: ankle -> hip
    "right_ankle", "right_knee", "right_hip",
    # left leg: hip -> ankle
    "left_hip", "left_knee", "left_ankle",
    # right arm: wrist -> shoulder
    "right_wrist", "right_elbow", "right_shoulder",
    # left arm: shoulder -> wrist
    "left_shoulder", "left_elbow", "left_wrist",
    # neck and head
    "neck", "head",
]
# Joint names in the order of the joints returned by ``SMPLX.forward``
# (kinematic joints, then facial landmarks, then vertex-based extra joints),
# 145 unique names in total:
#   [0, 55)    body (22), jaw and eyes (3), left hand (15), right hand (15)
#   [55, 123)  68 facial landmarks (the last 17 are the face contour)
#   [123, 145) extra joints, see ``extra_names`` below
SMPLX_names = [
    # body (22)
    "pelvis", "left_hip", "right_hip", "spine1",
    "left_knee", "right_knee", "spine2",
    "left_ankle", "right_ankle", "spine3",
    "left_foot", "right_foot", "neck",
    "left_collar", "right_collar", "head",
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    # jaw and eyes (3)
    "jaw", "left_eye_smplx", "right_eye_smplx",
    # left hand (15)
    "left_index1", "left_index2", "left_index3",
    "left_middle1", "left_middle2", "left_middle3",
    "left_pinky1", "left_pinky2", "left_pinky3",
    "left_ring1", "left_ring2", "left_ring3",
    "left_thumb1", "left_thumb2", "left_thumb3",
    # right hand (15)
    "right_index1", "right_index2", "right_index3",
    "right_middle1", "right_middle2", "right_middle3",
    "right_pinky1", "right_pinky2", "right_pinky3",
    "right_ring1", "right_ring2", "right_ring3",
    "right_thumb1", "right_thumb2", "right_thumb3",
    # facial landmarks (68): eyebrows (10)
    "right_eye_brow1", "right_eye_brow2", "right_eye_brow3",
    "right_eye_brow4", "right_eye_brow5",
    "left_eye_brow5", "left_eye_brow4", "left_eye_brow3",
    "left_eye_brow2", "left_eye_brow1",
    # nose (9)
    "nose1", "nose2", "nose3", "nose4",
    "right_nose_2", "right_nose_1", "nose_middle",
    "left_nose_1", "left_nose_2",
    # eyes (12)
    "right_eye1", "right_eye2", "right_eye3",
    "right_eye4", "right_eye5", "right_eye6",
    "left_eye4", "left_eye3", "left_eye2",
    "left_eye1", "left_eye6", "left_eye5",
    # mouth (12)
    "right_mouth_1", "right_mouth_2", "right_mouth_3", "mouth_top",
    "left_mouth_3", "left_mouth_2", "left_mouth_1", "left_mouth_5",
    "left_mouth_4", "mouth_bottom", "right_mouth_4", "right_mouth_5",
    # lips (8)
    "right_lip_1", "right_lip_2", "lip_top", "left_lip_2",
    "left_lip_1", "left_lip_3", "lip_bottom", "right_lip_3",
    # face contour (17)
    "right_contour_1", "right_contour_2", "right_contour_3",
    "right_contour_4", "right_contour_5", "right_contour_6",
    "right_contour_7", "right_contour_8",
    "contour_middle",
    "left_contour_8", "left_contour_7", "left_contour_6",
    "left_contour_5", "left_contour_4", "left_contour_3",
    "left_contour_2", "left_contour_1",
]
# Extra joints that SMPLX.forward appends after the landmarks (selected from
# mesh vertices by JointsFromVerticesSelector). They are appended to
# SMPLX_names exactly once, so every name is unique and index i of
# SMPLX_names matches index i of part_indices / the joints output.
# (Previously these names were listed inside SMPLX_names AND appended again,
# producing 167 entries with 22 duplicates.)
extra_names = [
    "head_top",
    "left_big_toe", "left_ear", "left_eye", "left_heel",
    "left_index", "left_middle", "left_pinky", "left_ring",
    "left_small_toe", "left_thumb",
    "nose",
    "right_big_toe", "right_ear", "right_eye", "right_heel",
    "right_index", "right_middle", "right_pinky", "right_ring",
    "right_small_toe", "right_thumb",
]
SMPLX_names += extra_names
| |
|
# Indices into SMPLX_names (and the joints returned by SMPLX.forward),
# grouped by body part. Contiguous runs are written as ranges; values are
# plain Python ints, so each array gets NumPy's default integer dtype.
part_indices = {
    "body": np.array([
        *range(0, 25),            # body, jaw and eye joints
        *range(123, 128), 132,    # head_top .. left_heel, left_small_toe
        *range(134, 139), 143,    # nose .. right_heel, right_small_toe
    ]),
    "torso": np.array([
        0, 1, 2, 3, 6, 9,
        *range(12, 20),
        *range(22, 25),
        *range(55, 60),
        *range(76, 145),
    ]),
    "head": np.array([
        12, 15,
        *range(22, 25),
        *range(55, 124),
        125, 126, 134, 136, 137,
    ]),
    # the 68 facial landmarks
    "face": np.array([*range(55, 123)]),
    "upper": np.array([12, 13, 14, *range(55, 123)]),
    "hand": np.array([
        20, 21,
        *range(25, 55),           # all finger joints
        *range(128, 132), 133,    # left fingertips
        *range(139, 143), 144,    # right fingertips
    ]),
    "left_hand": np.array([20, *range(25, 40), *range(128, 132), 133]),
    "right_hand": np.array([21, *range(40, 55), *range(139, 143), 144]),
}

# Joint indices from the head up to the pelvis (head, neck, spine3, spine2,
# spine1, pelvis); registered as a buffer in SMPLX.__init__.
head_kin_chain = [15, 12, 9, 6, 3, 0]
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
class SMPLX(nn.Module):
    """Differentiable SMPL-X body model.

    Given SMPL-X parameters, ``forward`` returns the posed mesh vertices,
    the 68 facial landmarks and the full set of 3D joints.

    ``config`` must provide ``smplx_model_path``, ``n_shape``, ``n_exp``,
    ``extra_joint_path`` and ``j14_regressor_path``.
    """

    # Kinematic chains used by ``pose_abs2rel`` / ``pose_rel2abs``: the target
    # joint first, followed by its ancestors up to the pelvis. Indices refer
    # to ``full_pose = cat([global_pose, body_pose], dim=1)``: index 0 is the
    # global (pelvis) rotation and index i >= 1 is ``body_pose[:, i - 1]``.
    _KIN_CHAINS = {
        # head <- neck <- spine3 <- spine2 <- spine1 <- pelvis
        "head": (15, 12, 9, 6, 3, 0),
        # neck <- spine3 <- spine2 <- spine1 <- pelvis
        "neck": (12, 9, 6, 3, 0),
        # wrist <- elbow <- shoulder <- collar <- spine3 <- ... <- pelvis
        "right_wrist": (21, 19, 17, 14, 9, 6, 3, 0),
        "left_wrist": (20, 18, 16, 13, 9, 6, 3, 0),
    }

    def __init__(self, config):
        super().__init__()
        ss = np.load(config.smplx_model_path, allow_pickle=True)
        smplx_model = Struct(**ss)

        self.dtype = torch.float32
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long),
        )
        self.register_buffer(
            "v_template",
            to_tensor(to_np(smplx_model.v_template), dtype=self.dtype))

        # Keep the first n_shape identity blend shapes and n_exp expression
        # blend shapes (expression components are read from index 300 on).
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat(
            [
                shapedirs[:, :, :config.n_shape],
                shapedirs[:, :, 300:300 + config.n_exp],
            ],
            2,
        )
        self.register_buffer("shapedirs", shapedirs)

        # Flatten pose blend shapes to [num_pose_basis, V * 3].
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs",
                             to_tensor(to_np(posedirs), dtype=self.dtype))
        self.register_buffer(
            "J_regressor",
            to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype))
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1  # the root joint has no parent
        self.register_buffer("parents", parents)
        self.register_buffer(
            "lbs_weights",
            to_tensor(to_np(smplx_model.weights), dtype=self.dtype))

        # Static and dynamic (pose-dependent contour) facial landmarks.
        self.register_buffer(
            "lmk_faces_idx",
            torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long))
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        self.register_buffer(
            "dynamic_lmk_faces_idx",
            torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords,
                         dtype=self.dtype),
        )
        self.register_buffer("head_kin_chain",
                             torch.tensor(head_kin_chain, dtype=torch.long))

        # Default parameters used by ``forward`` for omitted arguments: zero
        # shape/expression coefficients and identity rotations. Registered as
        # buffers (requires_grad=False) so they follow ``.to(device)``.
        self.register_buffer(
            "shape_params",
            nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype),
                         requires_grad=False),
        )
        self.register_buffer(
            "expression_params",
            nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype),
                         requires_grad=False),
        )
        for name, num_joints in (
            ("global_pose", 1),
            ("head_pose", 1),
            ("neck_pose", 1),
            ("jaw_pose", 1),
            ("eye_pose", 2),
            ("body_pose", 21),
            ("left_hand_pose", 15),
            ("right_hand_pose", 15),
        ):
            self.register_buffer(
                name,
                nn.Parameter(
                    torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(
                        num_joints, 1, 1),
                    requires_grad=False,
                ),
            )

        # Optional extra joints (nose, ears, toes, ...) picked from vertices.
        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(
                fname=config.extra_joint_path)
        # Set unconditionally: the J14 set-up below reads these attributes
        # even when no extra-joint file is configured.
        self.use_joint_regressor = True
        self.keypoint_names = SMPLX_names
        if self.use_joint_regressor:
            # NOTE(review): pickle can execute arbitrary code; the regressor
            # file must come from a trusted source.
            with open(config.j14_regressor_path, "rb") as f:
                j14_regressor = pickle.load(f, encoding="latin1")
            # Map SMPL-X keypoint indices (source) to J14 positions (target).
            source = []
            target = []
            for idx, name in enumerate(self.keypoint_names):
                if name in J14_NAMES:
                    source.append(idx)
                    target.append(J14_NAMES.index(name))
            self.register_buffer("source_idxs",
                                 torch.from_numpy(np.asarray(source)))
            self.register_buffer("target_idxs",
                                 torch.from_numpy(np.asarray(target)))
            joint_regressor = torch.from_numpy(j14_regressor).to(
                dtype=self.dtype)
            self.register_buffer("extra_joint_regressor", joint_regressor)
        self.part_indices = part_indices

    @staticmethod
    def _infer_batch_size(*tensors):
        """Return ``shape[0]`` of the first argument that is not None.

        Falls back to 1 when every argument is None, so ``forward()`` with no
        arguments produces one mesh from the default (zero shape/expression,
        identity pose) parameters.
        """
        for tensor in tensors:
            if tensor is not None:
                return tensor.shape[0]
        return 1

    def forward(
        self,
        shape_params=None,
        expression_params=None,
        global_pose=None,
        body_pose=None,
        jaw_pose=None,
        eye_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
    ):
        """Pose and shape the SMPL-X mesh.

        Any argument left as None is replaced by the corresponding default
        buffer, broadcast to the batch size.

        Args:
            shape_params: [N, number of shape parameters]
            expression_params: [N, number of expression parameters]
            global_pose: pelvis pose, [N, 1, 3, 3]
            body_pose: [N, 21, 3, 3]
            jaw_pose: [N, 1, 3, 3]
            eye_pose: [N, 2, 3, 3]
            left_hand_pose: [N, 15, 3, 3]
            right_hand_pose: [N, 15, 3, 3]
        Returns:
            vertices: [N, number of vertices, 3]
            landmarks: [N, number of landmarks (68 face keypoints), 3]
            joints: [N, number of smplx joints (145), 3]
        """
        # Batch size comes from the first supplied argument (shape_params,
        # then global_pose, then the remaining inputs); 1 if none is given.
        batch_size = self._infer_batch_size(
            shape_params, global_pose, body_pose, jaw_pose, eye_pose,
            left_hand_pose, right_hand_pose, expression_params)

        def expand_pose(default):
            # [J, 3, 3] default buffer -> [batch_size, J, 3, 3] broadcast view.
            return default.unsqueeze(0).expand(batch_size, -1, -1, -1)

        if shape_params is None:
            shape_params = self.shape_params.expand(batch_size, -1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)
        if global_pose is None:
            global_pose = expand_pose(self.global_pose)
        if body_pose is None:
            body_pose = expand_pose(self.body_pose)
        if jaw_pose is None:
            jaw_pose = expand_pose(self.jaw_pose)
        if eye_pose is None:
            eye_pose = expand_pose(self.eye_pose)
        if left_hand_pose is None:
            left_hand_pose = expand_pose(self.left_hand_pose)
        if right_hand_pose is None:
            right_hand_pose = expand_pose(self.right_hand_pose)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        full_pose = torch.cat(
            [
                global_pose,
                body_pose,
                jaw_pose,
                eye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(
            batch_size, -1, -1)

        # Poses are already rotation matrices, hence pose2rot=False.
        vertices, joints = lbs(
            shape_components,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            dtype=self.dtype,
            pose2rot=False,
        )

        # Static landmarks followed by the pose-dependent contour landmarks.
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(
            batch_size, -1, -1)
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.head_kin_chain,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        # Joint order: kinematic joints, landmarks, then extra joints.
        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            extra_joints = self.extra_joint_selector(vertices,
                                                     self.faces_tensor)
            final_joint_set.append(extra_joints)
        joints = torch.cat(final_joint_set, dim=1)

        return vertices, landmarks, joints

    def _kin_chain(self, abs_joint, caller):
        """Return the kinematic chain registered for ``abs_joint``.

        Raises:
            NotImplementedError: if ``abs_joint`` has no chain; the message
                names ``caller`` so each public method reports itself.
        """
        try:
            return self._KIN_CHAINS[abs_joint]
        except (KeyError, TypeError):
            raise NotImplementedError(
                f"{caller} does not support: {abs_joint}") from None

    def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"):
        """Convert the absolute rotation of ``abs_joint`` to a relative one.

        SMPL-X kinematic tree: absolute pose = parent pose @ relative pose.
        Poses are rotation matrices: ``global_pose`` [N, 1, 3, 3] and
        ``body_pose`` [N, 21, 3, 3]; the ``abs_joint`` entry of ``body_pose``
        holds its absolute rotation on input.

        ``body_pose`` is modified IN PLACE (that entry is overwritten with the
        relative rotation) and is also returned. The parent rotation is
        detached, so no gradient flows into the ancestor joints.

        Raises:
            NotImplementedError: for an unsupported ``abs_joint``.
        """
        kin_chain = self._kin_chain(abs_joint, "pose_abs2rel")

        batch_size = global_pose.shape[0]
        dtype = global_pose.dtype
        device = global_pose.device
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        # Absolute rotation of the parent: compose the chain from the parent
        # up to the root, left-multiplying each ancestor.
        rel_rot_mat = torch.eye(3, device=device, dtype=dtype).unsqueeze_(
            dim=0).repeat(batch_size, 1, 1)
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)

        abs_parent_pose = rel_rot_mat.detach()
        abs_joint_pose = body_pose[:, kin_chain[0] - 1]
        # R_rel = R_parent^T @ R_abs (a rotation's inverse is its transpose).
        rel_joint_pose = torch.matmul(
            abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
            abs_joint_pose.reshape(-1, 3, 3),
        )
        body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose
        return body_pose

    def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"):
        """Compute the absolute rotation of ``abs_joint`` from relative poses.

        SMPL-X kinematic tree: absolute pose = parent pose @ relative pose.
        Poses are rotation matrices: ``global_pose`` [N, 1, 3, 3] and
        ``body_pose`` [N, 21, 3, 3]. Inputs are not modified.

        Returns:
            [N, 1, 3, 3] absolute rotation of ``abs_joint``.

        Raises:
            NotImplementedError: for an unsupported ``abs_joint``.
        """
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        kin_chain = self._kin_chain(abs_joint, "pose_rel2abs")
        rel_rot_mat = torch.eye(3,
                                device=full_pose.device,
                                dtype=full_pose.dtype).unsqueeze_(dim=0)
        # Joint first, then each ancestor: R_root @ ... @ R_parent @ R_joint.
        for idx in kin_chain:
            rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
        abs_pose = rel_rot_mat[:, None, :, :]
        return abs_pose
| |
|