| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import pickle |
|
|
| import loguru |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler |
|
|
|
|
def to_tensor(array, dtype=torch.float32):
    """
    Convert ``array`` into a new ``torch.Tensor`` of the requested dtype.

    Input:
        array: a ``torch.Tensor`` or anything ``torch.tensor`` accepts
            (numpy array, nested list, scalar, ...)
        dtype: dtype of the returned tensor
    return:
        A tensor holding a copy of the data. For tensor inputs the copy is
        detached from the autograd graph and stays on the input's device.
    """
    if isinstance(array, torch.Tensor):
        # The previous string check ('torch.tensor' in str(type(array))) never
        # matched because the class is spelled 'torch.Tensor', so tensors were
        # copy-constructed through torch.tensor(), which emits a UserWarning.
        # clone().detach() is the warning-free equivalent of that copy.
        return array.clone().detach().to(dtype=dtype)
    return torch.tensor(array, dtype=dtype)
|
|
|
|
def to_np(array, dtype=np.float32):
    """
    Return ``array`` as a numpy ndarray of the given dtype.

    Objects whose type name contains 'scipy.sparse' are densified with
    ``todense()`` before the conversion.
    """
    is_sparse = 'scipy.sparse' in str(type(array))
    dense = array.todense() if is_sparse else array
    return np.array(dense, dtype=dtype)
|
|
|
|
class Struct(object):
    """Plain attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
|
|
|
|
class FLAME(nn.Module):
    """
    Differentiable FLAME head-model layer.

    Borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
    Given FLAME parameters, this class generates a differentiable FLAME function
    which outputs a mesh and 2D/3D facial landmarks.
    """

    def __init__(self, config, optimize_basis=False):
        """
        Load the FLAME model and the landmark embeddings named in ``config``.

        Input:
            config: object exposing ``flame_model_path``,
                ``flame_lmk_embedding_path`` and ``n_shape``
            optimize_basis: when True ``shapedirs`` is registered as a
                trainable parameter, otherwise as a buffer
        """
        super(FLAME, self).__init__()
        loguru.logger.info("[FLAME] creating the FLAME Decoder")
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # config.flame_model_path at model files from a trusted source.
        with open(config.flame_model_path, 'rb') as f:
            ss = pickle.load(f, encoding='latin1')
        flame_model = Struct(**ss)

        self.optimize_basis = optimize_basis
        self.cfg = config
        self.dtype = torch.float32
        self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))

        self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
        self.n_vertices = self.v_template.shape[0]

        # Keep the first n_shape columns plus every column from index 300 on.
        # forward() feeds [shape_params, expression_params] against this basis,
        # so the trailing columns are used as the expression components.
        shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat([shapedirs[:, :, :config.n_shape], shapedirs[:, :, 300:]], 2)

        if optimize_basis:
            self.register_parameter('shapedirs', torch.nn.Parameter(shapedirs))
        else:
            self.register_buffer('shapedirs', shapedirs)

        self.n_shape = config.n_shape

        # Flatten the pose correctives to (num_pose_basis, everything else).
        num_pose_basis = flame_model.posedirs.shape[-1]
        posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))

        self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
        parents = to_tensor(to_np(flame_model.kintree_table[0])).long()
        # The first joint has no parent; -1 is the sentinel that terminates
        # the kinematic-chain walk below.
        parents[0] = -1
        self.register_buffer('parents', parents)
        self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))

        # Fixed (non-trainable) zero defaults used when forward() gets no
        # eye / neck pose.
        default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
        self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose, requires_grad=False))
        default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
        self.register_parameter('neck_pose', nn.Parameter(default_neck_pose, requires_grad=False))

        # The embedding file stores a pickled dict inside a 0-d object array;
        # indexing with () unwraps it.
        lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1')
        lmk_embeddings = lmk_embeddings[()]

        def embedding(key, dtype):
            # torch.as_tensor accepts both numpy arrays and tensors, so every
            # entry is handled the same way regardless of how it was saved.
            return torch.as_tensor(lmk_embeddings[key]).to(dtype)

        self.register_buffer('lmk_faces_idx', embedding('static_lmk_faces_idx', torch.long))
        self.register_buffer('lmk_bary_coords', embedding('static_lmk_bary_coords', self.dtype))
        self.register_buffer('dynamic_lmk_faces_idx', embedding('dynamic_lmk_faces_idx', torch.long))
        self.register_buffer('dynamic_lmk_bary_coords', embedding('dynamic_lmk_bary_coords', self.dtype))
        self.register_buffer('full_lmk_faces_idx', embedding('full_lmk_faces_idx', torch.long))
        self.register_buffer('full_lmk_bary_coords', embedding('full_lmk_bary_coords', self.dtype))

        # Walk from the neck joint up through its parents until the -1 root
        # sentinel is reached.
        neck_kin_chain = []
        NECK_IDX = 1
        curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
        while curr_idx != -1:
            neck_kin_chain.append(curr_idx)
            curr_idx = self.parents[curr_idx]
        self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))

    def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
                                          dynamic_lmk_b_coords,
                                          neck_kin_chain, dtype=torch.float32):
        """
        Selects the face contour depending on the relative position of the head
        Input:
            pose: N X full pose
            dynamic_lmk_faces_idx: The list of contour face indexes
            dynamic_lmk_b_coords: The list of contour barycentric weights
            neck_kin_chain: The tree to consider for the relative rotation
            dtype: Data type
        return:
            The contour face indexes and the corresponding barycentric weights
        """
        batch_size = pose.shape[0]

        # Axis-angle rotations of the joints on the neck chain only.
        aa_pose = torch.index_select(pose.reshape(batch_size, -1, 3), 1,
                                     neck_kin_chain)
        rot_mats = batch_rodrigues(
            aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)

        # Compose the chain rotations into one relative rotation per sample.
        rel_rot_mat = torch.eye(3, device=pose.device,
                                dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
        for idx in range(len(neck_kin_chain)):
            rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)

        # Angle in whole degrees, capped at +39.
        y_rot_angle = torch.round(
            torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
                        max=39)).to(dtype=torch.long)

        # Fold negative angles into table rows 40..78: an angle a in [-39, -1]
        # maps to 39 - a, and anything below -39 saturates at row 78.
        neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
        mask = y_rot_angle.lt(-39).to(dtype=torch.long)
        neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
        y_rot_angle = (neg_mask * neg_vals +
                       (1 - neg_mask) * y_rot_angle)

        dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
                                               0, y_rot_angle)
        dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
                                              0, y_rot_angle)
        return dyn_lmk_faces_idx, dyn_lmk_b_coords

    def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
        """
        Calculates landmarks by barycentric interpolation
        Input:
            vertices: torch.tensor NxVx3, dtype = torch.float32
                The tensor of input vertices
            faces: torch.tensor Fx3, dtype = torch.long
                The faces of the mesh
            lmk_faces_idx: torch.tensor N X L, dtype = torch.long
                The tensor with the indices of the faces used to calculate the
                landmarks.
            lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
                The tensor of barycentric coordinates that are used to interpolate
                the landmarks

        Returns:
            landmarks: torch.tensor NxLx3, dtype = torch.float32
                The coordinates of the landmarks for each mesh in the batch
        """
        # Fixed: the slice bound used to be the undefined name `dd2`, which
        # raised NameError on every call.
        batch_size, num_verts = vertices.shape[:2]

        # Vertex indices of every landmark face, one row of 3 per landmark.
        lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.reshape(-1)).reshape(
            batch_size, lmk_faces_idx.shape[1], -1)

        # Shift indices so they address the batch-flattened vertex array.
        batch_offsets = torch.arange(
            batch_size, dtype=torch.long, device=vertices.device).view(-1, 1, 1) * num_verts
        lmk_faces = lmk_faces + batch_offsets

        lmk_vertices = vertices.reshape(-1, 3)[lmk_faces]
        landmarks = torch.einsum('blfi,blf->bli', lmk_vertices, lmk_bary_coords)
        return landmarks

    def compute_landmarks(self, vertices):
        """
        Landmarks of ``vertices`` using the full landmark embedding.
        Input:
            vertices: N X V X 3
        return:
            Result of ``vertices2landmarks`` for the ``full_lmk_*`` buffers,
            repeated once per batch element.
        """
        batch_size = vertices.shape[0]
        landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
                                         self.full_lmk_faces_idx.repeat(batch_size, 1),
                                         self.full_lmk_bary_coords.repeat(batch_size, 1, 1))
        return landmarks3d

    def seletec_3d68(self, vertices):
        """
        Same computation as ``compute_landmarks``; the (misspelled) name is
        kept so existing callers continue to work.
        """
        return self.compute_landmarks(vertices)

    def _shape_basis_matrix(self):
        """First ``n_shape`` shape directions flattened to (3 * V) X n_shape."""
        n_vertices = self.v_template.shape[0]
        return self.shapedirs[:, :, :self.n_shape].reshape(3 * n_vertices, self.n_shape)

    def _shape_offsets(self, shape_vector, shape_as_offset):
        """Return ``shape_vector`` as offsets from the template mesh."""
        if shape_as_offset:
            return shape_vector
        return shape_vector - self.v_template

    def project_to_shape_basis(self, shape_vector, shape_as_offset=False):
        """
        Project a shape onto the first ``n_shape`` shape directions.
        Input:
            shape_vector: N X V X 3 vertices, or offsets from the template
                when ``shape_as_offset`` is True
            shape_as_offset: skip the template subtraction
        return:
            N X n_shape dot products of the offsets with each basis column
        """
        batch_size = shape_vector.shape[0]
        diff = self._shape_offsets(shape_vector, shape_as_offset)
        return torch.matmul(diff.reshape(batch_size, -1), self._shape_basis_matrix())

    def compute_distance_to_basis(self, shape_vector, shape_as_offset=False):
        """
        Residual of a shape after removing its projection on the shape basis.
        Input:
            shape_vector: N X V X 3 vertices, or offsets from the template
                when ``shape_as_offset`` is True
            shape_as_offset: skip the template subtraction
        return:
            N X V X 3 tensor: offsets minus (basis @ basis^T @ offsets)
        """
        batch_size = shape_vector.shape[0]
        n_vertices = self.v_template.shape[0]

        diff = self._shape_offsets(shape_vector, shape_as_offset)
        basis = self._shape_basis_matrix()
        shape_params = torch.matmul(diff.reshape(batch_size, -1), basis)
        reconstruction = torch.matmul(shape_params, basis.t()).reshape(batch_size, n_vertices, 3)
        distance = diff - reconstruction
        return distance

    def get_std(self):
        """
        return:
            L2 norm of each of the first ``n_shape`` flattened basis columns
        """
        return torch.norm(self._shape_basis_matrix(), dim=0)

    def compute_closest_shape(self, shape_vector):
        """
        Reconstruct ``shape_vector`` from the first ``n_shape`` directions.
        Input:
            shape_vector: N X V X 3 vertices
        return:
            (N X V X 3 reconstruction, N X n_shape coefficients); each
            coefficient is the projection divided by the squared column norm.
        """
        B = shape_vector.shape[0]
        N = self.v_template.shape[0]

        basis = self._shape_basis_matrix()
        diff = (shape_vector - self.v_template).reshape(B, -1)
        # Dividing by the squared norm compensates for non-unit-length columns.
        inv = 1.0 / self.get_std().square()
        params = inv * torch.matmul(diff, basis)

        closest = self.v_template + torch.matmul(params, basis.T).reshape(B, N, 3)
        return closest, params

    def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None, neck_pose_params=None, shape_basis_delta=None):
        """
        Input:
            shape_params: N X number of shape parameters
            expression_params: N X number of expression parameters
            pose_params: N X number of pose parameters (6); the first 3 are
                placed before the neck pose and the rest after it
            eye_pose_params: N X 6, defaults to the ``eye_pose`` parameter
            neck_pose_params: N X 3, defaults to the ``neck_pose`` parameter
            shape_basis_delta: accepted for interface compatibility; unused
            Any omitted argument defaults to zeros (or to the stored eye/neck
            pose); N is taken from the first argument that is given, else 1.
        return:
            vertices: N X V X 3
            landmarks2d: N X number of landmarks X 3 (dynamic contour + static)
            landmarks3d: N X number of landmarks X 3 (full embedding)
        """
        provided = (shape_params, expression_params, pose_params,
                    eye_pose_params, neck_pose_params)
        # Previously shape_params.shape[0] was read unconditionally, so the
        # documented default of None crashed.
        batch_size = next((p.shape[0] for p in provided if p is not None), 1)
        device = self.neck_pose.device

        def zeros(width):
            return torch.zeros([batch_size, width], dtype=self.dtype, device=device)

        if shape_params is None:
            shape_params = zeros(self.n_shape)
        if pose_params is None:
            # Explicit zeros instead of borrowing self.eye_pose, which only
            # happened to work while that parameter was all zeros.
            pose_params = zeros(6)
        if eye_pose_params is None:
            eye_pose_params = self.eye_pose.expand(batch_size, -1)
        if neck_pose_params is None:
            neck_pose_params = self.neck_pose.expand(batch_size, -1)
        if expression_params is None:
            # Columns of shapedirs left after the shape block (100 for a
            # model whose basis has 400 columns) - assumes lbs expects betas
            # to span the whole registered basis; TODO confirm against lbs.
            expression_params = zeros(self.shapedirs.shape[-1] - self.n_shape)

        betas = torch.cat([shape_params, expression_params], dim=1)
        full_pose = torch.cat([pose_params[:, :3], neck_pose_params, pose_params[:, 3:], eye_pose_params], dim=1)
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)

        vertices, _ = lbs(betas, full_pose, template_vertices,
                          self.shapedirs, self.posedirs,
                          self.J_regressor, self.parents,
                          self.lbs_weights, dtype=self.dtype)

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)

        # Pose-dependent contour landmarks go in front of the static ones.
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
            full_pose, self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain, dtype=self.dtype)
        lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)

        landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
                                         lmk_faces_idx,
                                         lmk_bary_coords)
        landmarks3d = self.compute_landmarks(vertices)
        return vertices, landmarks2d, landmarks3d
|
|