Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2023 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: mica@tue.mpg.de | |
| import os | |
| import pickle | |
| from huggingface_hub import hf_hub_download | |
| from pixel3dmm import env_paths | |
| import numpy as np | |
| # Modified from smplx code for FLAME | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from skimage.io import imread | |
| from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix, matrix_to_rotation_6d | |
| from pixel3dmm.tracking.flame.lbs import lbs | |
| from pixel3dmm import env_paths | |
# Identity rotation in the 6D rotation representation. It is created on the GPU
# when CUDA is available and on the CPU otherwise, so that merely importing this
# module no longer fails on machines without CUDA.
_I_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
I = matrix_to_rotation_6d(torch.eye(3, device=_I_DEVICE)[None])
| def to_tensor(array, dtype=torch.float32): | |
| if 'torch.tensor' not in str(type(array)): | |
| return torch.tensor(array, dtype=dtype) | |
def to_np(array, dtype=np.float32):
    """
    Convert ``array`` into a NumPy ndarray of the requested ``dtype``.

    SciPy sparse matrices (recognised by their type's module path) are densified
    before conversion.
    """
    needs_densify = 'scipy.sparse' in str(type(array))
    dense = array.todense() if needs_densify else array
    return np.array(dense, dtype=dtype)
class Struct(object):
    """Minimal attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
def rot_mat_to_euler(rot_mats):
    """
    Compute one Euler angle per matrix for a batch of rotation matrices (N x 3 x 3).

    Returns atan2(-R[2, 0], sqrt(R[0, 0]^2 + R[1, 0]^2)) in radians. This is the
    y-axis angle of a Z-Y-X decomposition and lies in [-pi/2, pi/2].
    Careful with extreme cases such as Euler angles [0.0, pi, 0.0].
    """
    r00 = rot_mats[:, 0, 0]
    r10 = rot_mats[:, 1, 0]
    cos_y = torch.sqrt(r00 * r00 + r10 * r10)
    return torch.atan2(-rot_mats[:, 2, 0], cos_y)
class FLAME(nn.Module):
    """
    Differentiable FLAME head model, borrowed from
    https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py

    Given FLAME parameters for shape, pose and expression, ``forward`` returns the
    posed mesh vertices together with 3D facial landmarks. All rotations (global,
    neck, jaw, eyes) use the 6D rotation representation
    (see ``rotation_6d_to_matrix`` / ``matrix_to_rotation_6d``).
    """

    def __init__(self, config):
        """
        Load the FLAME assets and register them as buffers.

        Args:
            config: object providing ``num_shape_params`` and ``num_exp_params``,
                the number of identity / expression components to keep.
        """
        super(FLAME, self).__init__()
        # NOTE: pickle can execute arbitrary code; FLAME_ASSET must be a trusted local file.
        with open(f'{env_paths.FLAME_ASSET}', 'rb') as f:
            ss = pickle.load(f, encoding='latin1')
        flame_model = Struct(**ss)

        self.dtype = torch.float32
        # Triangle vertex indices of the mesh
        self.register_buffer('faces', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
        # The vertices of the template model
        self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
        # Keep the first `num_shape_params` identity components and the first
        # `num_exp_params` expression components. In the stored basis the expression
        # components start at column 300.
        shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat([shapedirs[:, :, :config.num_shape_params],
                               shapedirs[:, :, 300:300 + config.num_exp_params]], 2)
        self.register_buffer('shapedirs', shapedirs)
        # Pose components, flattened and transposed to (num_pose_basis, -1)
        num_pose_basis = flame_model.posedirs.shape[-1]
        posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
        # Joint regressor, kinematic tree and skinning weights
        self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
        parents = to_tensor(to_np(flame_model.kintree_table[0])).long()
        parents[0] = -1  # the root joint has no parent
        self.register_buffer('parents', parents)
        self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))
        # Eyelid blendshapes stored next to this file; [None] adds a leading batch axis
        blendshape_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'blendshapes')
        self.register_buffer('l_eyelid', torch.from_numpy(np.load(os.path.join(blendshape_dir, 'l_eyelid.npy'))).to(self.dtype)[None])
        self.register_buffer('r_eyelid', torch.from_numpy(np.load(os.path.join(blendshape_dir, 'r_eyelid.npy'))).to(self.dtype)[None])

        # Register all-zero, non-trainable default parameters. forward() falls back to
        # `shape_params` / `expression_params` when the corresponding argument is None.
        self._register_default_params('neck_pose_params', 6)
        self._register_default_params('jaw_pose_params', 6)
        self._register_default_params('eye_pose_params', 12)
        self._register_default_params('shape_params', config.num_shape_params)
        self._register_default_params('expression_params', config.num_exp_params)

        # Static and dynamic (contour) landmark embeddings for FLAME. The .npy file
        # stores a pickled object in a 0-d array; [()] extracts it (used as a dict below).
        lmk_embeddings = np.load(f'{env_paths.FLAME_MASK_ASSET}/FLAME2020/landmark_embedding.npy',
                                 allow_pickle=True, encoding='latin1')
        lmk_embeddings = lmk_embeddings[()]
        self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx'].astype(int)).to(torch.int64))
        self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
        self.register_buffer('dynamic_lmk_faces_idx', torch.from_numpy(np.array(lmk_embeddings['dynamic_lmk_faces_idx']).astype(int)).to(torch.int64))
        self.register_buffer('dynamic_lmk_bary_coords', torch.from_numpy(np.array(lmk_embeddings['dynamic_lmk_bary_coords'])).to(self.dtype))

        # Joint indices from the neck joint up to the root, following `parents`
        neck_kin_chain = []
        NECK_IDX = 1
        curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
        while curr_idx != -1:
            neck_kin_chain.append(curr_idx)
            curr_idx = self.parents[curr_idx]
        self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))

    def _find_dynamic_lmk_idx_and_bcoords(self, vertices, pose, dynamic_lmk_faces_idx,
                                          dynamic_lmk_b_coords,
                                          neck_kin_chain, cameras, dtype=torch.float32):
        """
        Select the face-contour landmarks from the relative rotation of the head.

        Input:
            vertices: N X num_of_vertices X 3 (only its batch size and device are used)
            pose: N X full pose, a concatenation of 6D rotations (one per joint)
            dynamic_lmk_faces_idx: table of contour face indices, indexed by angle bin
            dynamic_lmk_b_coords: table of contour barycentric weights, indexed the same way
            neck_kin_chain: joint indices whose rotations are composed into the relative rotation
            cameras: optional rotation(s) left-multiplied onto the relative rotation; may be None
            dtype: data type of the identity matrix that starts the composition
        return:
            The contour face indices and the corresponding barycentric weights,
            selected by the per-sample angle bin.
        """
        batch_size = vertices.shape[0]
        # Gather the 6D rotations of the joints on the neck chain and convert them to matrices
        chain_pose_6d = torch.index_select(pose.view(batch_size, -1, 6), 1, neck_kin_chain)
        rot_mats = rotation_6d_to_matrix(chain_pose_6d.view(-1, 6)).view([batch_size, -1, 3, 3])

        # Compose the chain rotations into one relative rotation per sample
        rel_rot_mat = torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
        for idx in range(len(neck_kin_chain)):
            rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
        if cameras is not None:
            rel_rot_mat = cameras @ rel_rot_mat  # Cameras flips z and x, plus multiview needs different lmk sliding per view

        # Convert the rotation to an integer angle in degrees, clamped to at most 39,
        # then map it to a table row:
        #   0..39      -> 0..39
        #   -1..-39    -> 40..78
        #   below -39  -> 78
        y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, max=39)).to(dtype=torch.long)
        neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
        mask = y_rot_angle.lt(-39).to(dtype=torch.long)
        neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
        y_rot_angle = (neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle)

        dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle)
        dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle)
        return dyn_lmk_faces_idx, dyn_lmk_b_coords

    def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
        """
        Calculate landmarks by barycentric interpolation.

        Input:
            vertices: torch.tensor NxVx3, dtype = torch.float32
                The tensor of input vertices
            faces: torch.tensor Fx3, dtype = torch.long
                The faces of the mesh, shared by all samples in the batch
            lmk_faces_idx: torch.tensor N X L, dtype = torch.long
                The indices of the faces used to calculate the landmarks
            lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
                The barycentric coordinates used to interpolate the landmarks
        Returns:
            landmarks: torch.tensor NxLx3, dtype = torch.float32
                The coordinates of the landmarks for each mesh in the batch
        """
        batch_size, num_verts = vertices.shape[:2]
        device = vertices.device
        # Vertex indices of each landmark face: N x L x 3
        lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1).to(torch.long)).view(batch_size, -1, 3)
        # Offset indices per sample so they address the flattened (N*V) x 3 vertex array
        lmk_faces += torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
        lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
        landmarks = torch.einsum('blfi,blf->bli', lmk_vertices, lmk_bary_coords)
        return landmarks

    @staticmethod
    def _infer_batch_size(*per_sample_params):
        """Return the leading dimension of the first non-None argument, or 1 if all are None."""
        for param in per_sample_params:
            if param is not None:
                return param.shape[0]
        return 1

    def forward(self, shape_params,
                cameras,
                trans_params=None,
                rot_params=None,
                neck_pose_params=None,
                jaw_pose_params=None,
                eye_pose_params=None,
                expression_params=None,
                eyelid_params=None,
                rot_params_lmk_shift=None,
                vertex_offsets=None,
                ):
        """
        Evaluate the FLAME model. All rotations use the 6D rotation representation.

        Input:
            shape_params: N X num_shape_params identity coefficients, or None for the
                zero default (the batch size is then taken from another argument, or 1)
            cameras: optional rotation(s) applied to the relative head rotation when
                choosing the dynamic contour landmarks; may be None
            trans_params: N X 3 global translation; currently NOT applied
                (always zero in this code base)
            rot_params: N X 6 global rotation around the root joint of the kinematic
                tree (rotation is NOT around the origin!); defaults to identity
            neck_pose_params (optional): N X 6 rotation of the head around the neck joint
            jaw_pose_params (optional): N X 6 rotation of the jaw
            eye_pose_params (optional): N X 12, two concatenated 6D eyeball rotations
            expression_params (optional): N X num_exp_params expression coefficients
            eyelid_params (optional): N X 2 weights. Column 0 scales the left-eyelid
                blendshape; column 1 scales the right-eyelid blendshape.
            rot_params_lmk_shift (optional): global rotation used only to select the
                dynamic contour landmarks; defaults to rot_params
            vertex_offsets: accepted for compatibility but currently ignored
        return:
            vertices: N X V X 3 posed vertices, including the eyelid blendshapes
            lmk68: N X number of landmarks X 3; dynamic contour landmarks followed
                by static landmarks
            joint_transforms: joint transforms returned by ``lbs`` for the full pose
            v_can: third output of ``lbs`` for the no-neck pose (see the note below)
            vertices_noneck: N X V X 3 vertices posed with the neck rotation replaced
                by identity; eyelid blendshapes are not applied to these
        """
        batch_size = self._infer_batch_size(shape_params, expression_params, rot_params,
                                            rot_params_lmk_shift, neck_pose_params,
                                            jaw_pose_params, eye_pose_params,
                                            eyelid_params, trans_params)
        # Build defaults on the device of the model buffers instead of a hard-coded
        # `.cuda()`, so the model also runs on CPU and on non-default GPUs.
        device = self.v_template.device
        identity_6d = matrix_to_rotation_6d(torch.eye(3, device=device)[None].repeat(batch_size, 1, 1))
        if trans_params is None:
            trans_params = torch.zeros(batch_size, 3, device=device)
        if rot_params is None:
            rot_params = identity_6d.clone()
        if rot_params_lmk_shift is None:
            rot_params_lmk_shift = rot_params
        if neck_pose_params is None:
            neck_pose_params = identity_6d.clone()
        if jaw_pose_params is None:
            jaw_pose_params = identity_6d.clone()
        if eye_pose_params is None:
            eye_pose_params = torch.cat([identity_6d.clone()] * 2, dim=1)
        if shape_params is None:
            shape_params = self.shape_params.expand(batch_size, -1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)

        # Concatenate identity shape and expression parameters
        betas = torch.cat([shape_params, expression_params], dim=1)
        # The pose vector contains the global, neck, jaw and eyeball rotations
        full_pose = torch.cat([rot_params, neck_pose_params, jaw_pose_params, eye_pose_params], dim=1)
        full_pose_no_neck = torch.cat([rot_params, identity_6d, jaw_pose_params, eye_pose_params], dim=1)
        full_pose_lmk_shift = torch.cat([rot_params_lmk_shift, neck_pose_params, jaw_pose_params, eye_pose_params], dim=1)
        # FLAME models shape and expression deformations as vertex offsets from the
        # mean face in 'zero pose', called v_template
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)

        # Use linear blend skinning to model pose rotations
        vertices, joint_transforms, v_can = lbs(betas, full_pose, template_vertices,
                                                self.shapedirs, self.posedirs,
                                                self.J_regressor, self.parents,
                                                self.lbs_weights, dtype=self.dtype)
        # NOTE(review): this overwrites `v_can` from the full-pose call above, so the
        # returned v_can comes from the no-neck pose. Kept as-is; confirm it is intended.
        vertices_noneck, _, v_can = lbs(betas, full_pose_no_neck, template_vertices,
                                        self.shapedirs, self.posedirs,
                                        self.J_regressor, self.parents,
                                        self.lbs_weights, dtype=self.dtype)

        # vertex_offsets are currently not applied:
        # vertices[:, self.vertex_face_mask, :] = vertices[:, self.vertex_face_mask, :] + vertex_offsets

        if eyelid_params is not None:
            vertices = vertices + self.r_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 1:2, None]
            vertices = vertices + self.l_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 0:1, None]

        # Landmarks: dynamic contour landmarks (chosen from the head rotation)
        # followed by the static ones
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1).contiguous()
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
            vertices, full_pose_lmk_shift, self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain, cameras, dtype=self.dtype)
        lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
        lmk68 = self._vertices2landmarks(vertices, self.faces, lmk_faces_idx, lmk_bary_coords)

        # Translation is always zero in this code base, so it is not applied:
        # vertices = vertices + trans_params.unsqueeze(dim=1)
        # lmk68 = lmk68 + trans_params.unsqueeze(dim=1)
        return vertices, lmk68, joint_transforms, v_can, vertices_noneck

    def _register_default_params(self, param_fname, dim):
        """Register a non-trainable all-zero (1, dim) parameter named `param_fname`."""
        default_params = torch.zeros([1, dim], dtype=self.dtype, requires_grad=False)
        self.register_parameter(param_fname, nn.Parameter(default_params, requires_grad=False))
class FLAMETex(nn.Module):
    """
    Linear texture model: texture = mean + basis @ texcode.

    ``forward`` returns an image batch (N, 3, tex_res, tex_res) with channels
    reversed and values divided by 255.
    """

    def __init__(self, config, texture_mask_index, tex_res):
        """
        Args:
            config: object providing ``tex_space_path`` (an .npz texture space) and
                ``tex_params`` (number of principal components to keep).
            texture_mask_index: pair of index arrays (rows, cols) of texels that
                receive ``tex_offsets`` in ``forward``.
            tex_res: output texture resolution (square).
        """
        super(FLAMETex, self).__init__()
        # Use the NpzFile as a context manager so its file handle is closed once the
        # arrays have been read (np.load on an .npz keeps the file open otherwise).
        with np.load(config.tex_space_path) as tex_space:
            if 'tex_dir' in tex_space.files:
                # FLAME texture
                mu_key, pc_key, n_pc, scale = 'mean', 'tex_dir', 200, 1
            else:
                # BFM to FLAME texture
                mu_key, pc_key, n_pc, scale = 'MU', 'PC', 199, 255.0
            texture_mean = tex_space[mu_key].reshape(1, -1)
            texture_basis = tex_space[pc_key].reshape(-1, n_pc)

        n_tex = config.tex_params
        texture_mean = torch.from_numpy(texture_mean).float()[None, ...] * scale
        texture_basis = torch.from_numpy(texture_basis[:, :n_tex]).float()[None, ...] * scale
        self.texture = None
        self.register_buffer('texture_mean', texture_mean)
        self.register_buffer('texture_basis', texture_basis)
        self.image_size = (512, 512)  # config.image_size
        # self.check_texture(config)
        self.texture_mask_index = texture_mask_index
        self.tex_res = tex_res

    def check_texture(self, config):
        """Load ``<config.actor>/texture.png`` as a fixed texture if it exists."""
        path = os.path.join(config.actor, 'texture.png')
        if os.path.exists(path):
            # Place the texture on the module's device instead of a hard-coded `.cuda()`
            self.texture = torch.from_numpy(imread(path)).permute(2, 0, 1).to(self.texture_mean.device)[None, 0:3, :, :] / 255.0

    def forward(self, texcode, tex_offsets=None):
        """
        Args:
            texcode: N X tex_params texture coefficients.
            tex_offsets (optional): values added to the texels selected by
                ``texture_mask_index`` after normalization.
        Returns:
            N X 3 X tex_res X tex_res texture, or the fixed texture resized to
            ``image_size`` when one was loaded by ``check_texture``.
        """
        if self.texture is not None:
            return F.interpolate(self.texture, self.image_size, mode='bilinear')
        texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
        # The texture space stores 512 x 512 x 3 texels per sample
        texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
        texture = F.interpolate(texture, (self.tex_res, self.tex_res), mode='bilinear')
        texture = texture[:, [2, 1, 0], :, :]  # reverse channel order
        texture = texture / 255.
        if tex_offsets is not None:
            texture[:, :, self.texture_mask_index[0], self.texture_mask_index[1]] += tex_offsets
        return texture