| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import |
| | from __future__ import print_function |
| | from __future__ import division |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
    """Calculates the rotation matrices for a batch of rotation vectors.

    Uses Rodrigues' formula ``R = I + sin(theta) K + (1 - cos(theta)) K @ K``,
    where ``theta`` is the norm of the axis-angle vector and ``K`` is the
    skew-symmetric cross-product matrix of the unit rotation axis.

    Parameters
    ----------
    rot_vecs: torch.tensor Nx3
        array of N axis-angle vectors
    epsilon: float, optional
        Offset added to every component of ``rot_vecs`` before the norm is
        taken, so that an all-zero rotation vector yields a small non-zero
        angle instead of a division by zero. Default 1e-8.
    dtype: torch.dtype, optional
        dtype of the zero padding and identity tensors created inside this
        function. Default torch.float32.

    Returns
    -------
    R: torch.tensor Nx3x3
        The rotation matrices for the given axis-angle parameters
    """
    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device

    # Shift by epsilon before taking the norm: for a zero vector the angle is
    # then tiny but non-zero, rot_dir becomes 0, and R reduces to the identity.
    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)  # Nx1
    rot_dir = rot_vecs / angle  # Nx3

    cos = torch.unsqueeze(torch.cos(angle), dim=1)  # Nx1x1
    sin = torch.unsqueeze(torch.sin(angle), dim=1)  # Nx1x1

    # Skew-symmetric matrix K with K @ v == rot_dir x v.
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
        (batch_size, 3, 3)
    )

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat
| |
|
| |
|
def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
    """Calculates landmarks by barycentric interpolation

    Parameters
    ----------
    vertices: torch.tensor BxVx3, dtype = torch.float32
        The tensor of input vertices
    faces: torch.tensor Fx3, dtype = torch.long
        The faces of the mesh
    lmk_faces_idx: torch.tensor BxL (or L), dtype = torch.long
        The indices of the faces used to calculate the landmarks. A 1-D
        tensor is shared by every mesh in the batch.
    lmk_bary_coords: torch.tensor BxLx3 (or Lx3), dtype = torch.float32
        The barycentric coordinates used to interpolate the landmarks. A 2-D
        tensor is shared by every mesh in the batch.

    Returns
    -------
    landmarks: torch.tensor BxLx3, dtype = torch.float32
        The coordinates of the landmarks for each mesh in the batch
    """
    batch_size, num_verts = vertices.shape[:2]
    device = vertices.device

    # Broadcast unbatched landmark definitions over the whole batch.
    if lmk_faces_idx.dim() == 1:
        lmk_faces_idx = lmk_faces_idx.unsqueeze(0).expand(batch_size, -1)
    if lmk_bary_coords.dim() == 2:
        lmk_bary_coords = lmk_bary_coords.unsqueeze(0).expand(batch_size, -1, -1)

    # Vertex indices of every landmark face: BxLx3. reshape (not view) because
    # an expanded index tensor is not contiguous.
    lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.reshape(-1)).view(
        batch_size, -1, 3
    )

    # Offset each batch entry's indices so they address the flattened
    # (B*V)x3 vertex tensor below.
    lmk_faces = lmk_faces + (
        torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1)
        * num_verts
    )

    # Corner positions of each landmark face: BxLx3(corners)x3(xyz).
    # reshape tolerates non-contiguous vertices.
    lmk_vertices = vertices.reshape(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)

    # Weighted sum of the three corners by the barycentric coordinates.
    landmarks = torch.einsum("blfi,blf->bli", [lmk_vertices, lmk_bary_coords])
    return landmarks
| |
|
| |
|
def lbs(
    pose,
    v_shaped,
    posedirs,
    J_regressor,
    parents,
    lbs_weights,
    pose2rot=True,
    dtype=torch.float32,
):
    """Performs Linear Blend Skinning on already-shaped vertices.

    Parameters
    ----------
    pose : torch.tensor
        If ``pose2rot`` is True: axis-angle parameters, Bx(J * 3) (any
        layout that reshapes to (B * J)x3). Otherwise rotation matrices with
        a separate joint axis, BxJx9 (or BxJx3x3); the first joint is dropped
        via ``pose[:, 1:]`` when building the pose feature.
    v_shaped : torch.tensor BxVx3
        Rest-pose vertices (shape already applied) to be posed.
    posedirs : torch.tensor ((J - 1) * 9)x(V * 3)
        Pose-corrective blend shapes, multiplied by the flattened
        ``R - I`` of every non-root joint.
    J_regressor : torch.tensor JxV
        The regressor array that is used to calculate the joints from
        the position of the vertices
    parents : torch.tensor J
        The array that describes the kinematic tree for the model
    lbs_weights : torch.tensor VxJ
        Skinning weights; broadcast over the batch.
    pose2rot : bool, optional
        Whether ``pose`` is axis-angle (True, default) that must be turned
        into rotation matrices, or already contains rotation matrices.
    dtype : torch.dtype, optional
        dtype of the identity / homogeneous-coordinate tensors created here
        and forwarded to the helper functions.

    Returns
    -------
    verts : torch.tensor BxVx3
        The posed vertices.
    J_transformed : torch.tensor BxJx3
        The posed joint locations.
    torch.tensor Bx4x4
        The relative rigid transform of joint index 1, as produced by
        ``batch_rigid_transform``.
    """
    n_batch = pose.shape[0]
    dev = pose.device

    # Rest-pose joints regressed from the shaped mesh.
    rest_joints = vertices2joints(J_regressor, v_shaped)

    eye3 = torch.eye(3, dtype=dtype, device=dev)
    if not pose2rot:
        # Pose already holds rotation matrices.
        pose_feature = pose[:, 1:].view(n_batch, -1, 3, 3) - eye3
        rot_mats = pose.view(n_batch, -1, 3, 3)
    else:
        rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view(
            [n_batch, -1, 3, 3]
        )
        pose_feature = rot_mats[:, 1:, :, :] - eye3

    # Pose-dependent corrective offsets, BxVx3.
    flat_feature = pose_feature.view(n_batch, -1)
    pose_offsets = torch.matmul(flat_feature, posedirs).view(n_batch, -1, 3)
    v_posed = pose_offsets + v_shaped

    J_transformed, A = batch_rigid_transform(
        rot_mats, rest_joints, parents, dtype=dtype
    )

    # Blend the per-joint 4x4 transforms into one transform per vertex.
    weights = lbs_weights.unsqueeze(dim=0).expand([n_batch, -1, -1])
    n_joints = J_regressor.shape[0]
    flat_A = A.view(n_batch, n_joints, 16)
    per_vertex_T = torch.matmul(weights, flat_A).view(n_batch, -1, 4, 4)

    # Apply them to the homogeneous posed vertices.
    ones_col = torch.ones(
        [n_batch, v_posed.shape[1], 1], dtype=dtype, device=dev
    )
    v_posed_homo = torch.cat([v_posed, ones_col], dim=2)
    v_homo = torch.matmul(per_vertex_T, v_posed_homo.unsqueeze(-1))

    verts = v_homo[:, :, :3, 0]
    return verts, J_transformed, A[:, 1]
| |
|
| |
|
def vertices2joints(J_regressor, vertices):
    """Regress 3D joint locations from mesh vertices.

    Every joint is a linear combination of the vertices, weighted by the
    matching row of ``J_regressor``.

    Parameters
    ----------
    J_regressor : torch.tensor JxV
        Regression weights; row ``j`` holds the per-vertex weights of joint ``j``.
    vertices : torch.tensor BxVx3
        The tensor of mesh vertices

    Returns
    -------
    torch.tensor BxJx3
        The location of the joints
    """
    # joints[b, j, c] = sum_v J_regressor[j, v] * vertices[b, v, c]
    return torch.einsum("bvc,jv->bjc", vertices, J_regressor)
| |
|
| |
|
def blend_shapes(betas, shape_disps):
    """Compute per-vertex displacements as a linear mix of blend shapes.

    Parameters
    ----------
    betas : torch.tensor Bx(num_betas)
        Blend shape coefficients
    shape_disps: torch.tensor Vx3x(num_betas)
        Blend shapes

    Returns
    -------
    torch.tensor BxVx3
        The per-vertex displacement due to shape deformation
    """
    # disp[b, v, c] = sum_s betas[b, s] * shape_disps[v, c, s]
    return torch.einsum("bs,vcs->bvc", betas, shape_disps)
| |
|
| |
|
def transform_mat(R, t):
    """Assemble a batch of 4x4 homogeneous transformation matrices.

    Args:
        - R: Bx3x3 array of a batch of rotation matrices
        - t: Bx3x1 array of a batch of translation vectors
    Returns:
        - T: Bx4x4 matrices laid out as [[R, t], [0 0 0, 1]]
    """
    # Add a zero row under R: Bx3x3 -> Bx4x3.
    top_left = F.pad(R, [0, 0, 0, 1])
    # Add a trailing 1 under t: Bx3x1 -> Bx4x1.
    right_col = F.pad(t, [0, 0, 0, 1], value=1)
    return torch.cat((top_left, right_col), dim=2)
| |
|
| |
|
def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
    """
    Applies a batch of rigid transformations to the joints

    Parameters
    ----------
    rot_mats : torch.tensor BxNx3x3
        Tensor of rotation matrices (need not be contiguous)
    joints : torch.tensor BxNx3
        Locations of joints
    parents : torch.tensor N
        The kinematic tree, shared by the whole batch: ``parents[i]`` is the
        index of joint ``i``'s parent. Entry 0 (the root) is never read; every
        other entry must refer to an already-processed joint (index < i).
    dtype : torch.dtype, optional:
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    posed_joints : torch.tensor BxNx3
        The locations of the joints after applying the pose rotations
    rel_transforms : torch.tensor BxNx4x4
        Each joint's accumulated transform with its rest-pose location
        removed from the translation, so it maps rest-pose points
        (R @ (p - j_rest) + posed joint) into the posed frame.
    """
    joints = torch.unsqueeze(joints, dim=-1)  # BxNx3x1
    num_joints = joints.shape[1]

    # Express every non-root joint relative to its parent.
    rel_joints = joints.clone().contiguous()
    rel_joints[:, 1:] = rel_joints[:, 1:] - joints[:, parents[1:]]

    # Local (parent-relative) 4x4 transforms. reshape instead of view so a
    # non-contiguous rot_mats is accepted.
    transforms_mat = transform_mat(
        rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
    ).reshape(-1, num_joints, 4, 4)

    # Convert once so the loop below indexes with plain ints instead of
    # tensor elements (avoids a host/device sync per joint on GPU).
    parent_ids = parents.tolist()

    # Accumulate along the kinematic chain: world_i = world_parent @ local_i.
    transform_chain = [transforms_mat[:, 0]]
    for i in range(1, len(parent_ids)):
        curr_res = torch.matmul(transform_chain[parent_ids[i]], transforms_mat[:, i])
        transform_chain.append(curr_res)

    transforms = torch.stack(transform_chain, dim=1)  # BxNx4x4

    # Translation column of the world transforms = posed joint positions.
    posed_joints = transforms[:, :, :3, 3]

    # Rest joints padded with a 0 (not 1), so the product below is R @ j_rest
    # with a zero in the homogeneous row.
    joints_homogen = F.pad(joints, [0, 0, 0, 1])

    # Subtract R @ j_rest from the translation column (padding places the
    # BxNx4x1 product into the last column of a BxNx4x4 tensor).
    rel_transforms = transforms - F.pad(
        torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
    )

    return posed_joints, rel_transforms
| |
|