# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: mica@tue.mpg.de

# Modified from smplx code for FLAME

import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from skimage.io import imread

from pixel3dmm import env_paths
from pixel3dmm.tracking.flame.lbs import lbs
from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix, matrix_to_rotation_6d

# Identity rotation in the 6D rotation representation, shape (1, 6).
# Kept as a public module constant. It lives on the GPU when one is available
# and falls back to the CPU so that importing this module does not crash on
# CUDA-less machines.
I = matrix_to_rotation_6d(torch.eye(3)[None].to('cuda' if torch.cuda.is_available() else 'cpu'))


def to_tensor(array, dtype=torch.float32):
    """Convert ``array`` to a new ``torch.Tensor`` of the requested dtype.

    Tensor inputs are returned as a detached copy (same device) instead of
    being routed through ``torch.tensor``, which warns on tensor input.
    """
    if isinstance(array, torch.Tensor):
        return array.clone().detach().to(dtype=dtype)
    return torch.tensor(array, dtype=dtype)


def to_np(array, dtype=np.float32):
    """Convert ``array`` (dense, or a scipy sparse matrix) to a numpy array."""
    # The sparse check is done on the type name so scipy is not a hard import.
    if 'scipy.sparse' in str(type(array)):
        array = array.todense()
    return np.array(array, dtype=dtype)


class Struct(object):
    """Plain attribute container: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def rot_mat_to_euler(rot_mats):
    """Return one Euler angle (in radians) per rotation matrix.

    ``rot_mats`` is indexed as ``[:, row, col]``; the result is
    ``atan2(-R[2, 0], sqrt(R[0, 0]^2 + R[1, 0]^2))`` for every matrix.
    """
    # Calculates rotation matrix to euler angles
    # Careful for extreme cases of euler angles like [0.0, pi, 0.0]
    sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
                    rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
    return torch.atan2(-rot_mats[:, 2, 0], sy)


class FLAME(nn.Module):
    """
    borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
    Given FLAME parameters for shape, pose, and expression, this class generates a differentiable FLAME function
    which outputs a mesh and 2D/3D facial landmarks
    """

    def __init__(self, config):
        """Load the FLAME assets and register them as buffers.

        ``config`` must provide ``num_shape_params`` and ``num_exp_params``.
        The model file is read from ``env_paths.FLAME_ASSET`` and the landmark
        embedding from ``env_paths.FLAME_MASK_ASSET``.
        """
        super(FLAME, self).__init__()

        # SECURITY NOTE: pickle.load executes arbitrary code from the file.
        # Only point FLAME_ASSET at a model file obtained from a trusted source.
        with open(f'{env_paths.FLAME_ASSET}', 'rb') as f:
            ss = pickle.load(f, encoding='latin1')
            flame_model = Struct(**ss)

        self.dtype = torch.float32
        self.register_buffer('faces', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
        # The vertices of the template model
        self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
        # The shape components and expression
        shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
        # Keep the first num_shape_params components and num_exp_params
        # components starting at index 300
        # (presumably where the expression basis starts in the asset -- TODO confirm).
        shapedirs = torch.cat([shapedirs[:, :, :config.num_shape_params],
                               shapedirs[:, :, 300:300 + config.num_exp_params]], 2)
        self.register_buffer('shapedirs', shapedirs)
        # The pose components
        num_pose_basis = flame_model.posedirs.shape[-1]
        posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
        #
        self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
        parents = to_tensor(to_np(flame_model.kintree_table[0])).long()
        # The root joint has no parent; mark it with -1 so chain walks terminate.
        parents[0] = -1
        self.register_buffer('parents', parents)
        self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))

        # Eyelid blendshapes shipped next to this file.
        blendshape_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'blendshapes')
        self.register_buffer(
            'l_eyelid',
            torch.from_numpy(np.load(os.path.join(blendshape_dir, 'l_eyelid.npy'))).to(self.dtype)[None])
        self.register_buffer(
            'r_eyelid',
            torch.from_numpy(np.load(os.path.join(blendshape_dir, 'r_eyelid.npy'))).to(self.dtype)[None])

        # Register default parameters (poses use the 6D rotation representation)
        self._register_default_params('neck_pose_params', 6)
        self._register_default_params('jaw_pose_params', 6)
        self._register_default_params('eye_pose_params', 12)
        self._register_default_params('shape_params', config.num_shape_params)
        self._register_default_params('expression_params', config.num_exp_params)

        # Static and Dynamic Landmark embeddings for FLAME
        # SECURITY NOTE: allow_pickle=True also unpickles; trusted files only.
        lmk_embeddings = np.load(f'{env_paths.FLAME_MASK_ASSET}/FLAME2020/landmark_embedding.npy',
                                 allow_pickle=True, encoding='latin1')
        # The file stores a 0-d object array wrapping a dict; [()] unwraps it.
        lmk_embeddings = lmk_embeddings[()]
        self.register_buffer(
            'lmk_faces_idx',
            torch.from_numpy(lmk_embeddings['static_lmk_faces_idx'].astype(int)).to(torch.int64))
        self.register_buffer(
            'lmk_bary_coords',
            torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
        self.register_buffer(
            'dynamic_lmk_faces_idx',
            torch.from_numpy(np.array(lmk_embeddings['dynamic_lmk_faces_idx']).astype(int)).to(torch.int64))
        self.register_buffer(
            'dynamic_lmk_bary_coords',
            torch.from_numpy(np.array(lmk_embeddings['dynamic_lmk_bary_coords'])).to(self.dtype))

        # Walk from the neck joint up to the root to collect the kinematic chain.
        neck_kin_chain = []
        NECK_IDX = 1
        curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
        while curr_idx != -1:
            neck_kin_chain.append(curr_idx)
            curr_idx = self.parents[curr_idx]
        self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))

    def _find_dynamic_lmk_idx_and_bcoords(self, vertices, pose, dynamic_lmk_faces_idx,
                                          dynamic_lmk_b_coords,
                                          neck_kin_chain, cameras, dtype=torch.float32):
        """
            Selects the face contour depending on the relative position of the head
            Input:
                vertices: N X num_of_vertices X 3 (only batch size and device are used)
                pose: N X full pose, 6 values (6D rotation) per joint
                dynamic_lmk_faces_idx: The list of contour face indexes
                dynamic_lmk_b_coords: The list of contour barycentric weights
                neck_kin_chain: The tree to consider for the relative rotation
                cameras: optional rotation(s) left-multiplied onto the relative
                    head rotation; skipped when None
                dtype: Data type
            return:
                The contour face indexes and the corresponding barycentric weights
        """

        batch_size = vertices.shape[0]

        aa_pose = torch.index_select(pose.view(batch_size, -1, 6), 1, neck_kin_chain)
        rot_mats = rotation_6d_to_matrix(aa_pose.view(-1, 6)).view([batch_size, -1, 3, 3])

        # Accumulate the rotations along the neck chain.
        rel_rot_mat = torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
        for idx in range(len(neck_kin_chain)):
            rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)

        if cameras is not None:
            # Cameras flips z and x, plus multiview needs different lmk sliding per view
            rel_rot_mat = cameras @ rel_rot_mat

        # Angle in whole degrees, clamped to at most 39; negative angles are
        # remapped to the rows 40..78 of the dynamic landmark tables.
        y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, max=39)).to(dtype=torch.long)
        neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
        mask = y_rot_angle.lt(-39).to(dtype=torch.long)
        neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
        y_rot_angle = (neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle)

        dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle)
        dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle)

        return dyn_lmk_faces_idx, dyn_lmk_b_coords

    def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
        """
            Calculates landmarks by barycentric interpolation
            Input:
                vertices: torch.tensor NxVx3, dtype = torch.float32
                    The tensor of input vertices
                faces: torch.tensor Fx3, dtype = torch.long
                    The faces of the mesh (shared by all batch elements)
                lmk_faces_idx: torch.tensor N X L, dtype = torch.long
                    The tensor with the indices of the faces used to calculate the
                    landmarks.
                lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
                    The tensor of barycentric coordinates that are used to interpolate
                    the landmarks

            Returns:
                landmarks: torch.tensor NxLx3, dtype = torch.float32
                    The coordinates of the landmarks for each mesh in the batch
        """
        # Extract the indices of the vertices for each face
        # NxLx3
        batch_size, num_verts = vertices.shape[:2]
        device = vertices.device
        lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1).to(torch.long)).view(batch_size, -1, 3)
        # Offset vertex indices per batch element so they address the flattened
        # (N*V) x 3 vertex tensor below.
        lmk_faces += torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
        lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
        landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
        return landmarks

    def forward(self, shape_params, cameras, trans_params=None, rot_params=None, neck_pose_params=None,
                jaw_pose_params=None, eye_pose_params=None, expression_params=None, eyelid_params=None,
                rot_params_lmk_shift=None,
                vertex_offsets=None,
                ):
        """
            Input:
                shape_params: N X number of shape parameters; when None the
                    registered zero default is used and N is taken from the
                    first other per-sample input that is given (1 if none)
                cameras: passed to the dynamic landmark selection, may be None
                trans_params: accepted for API compatibility; the translation is
                    NOT applied (it is always zero in this code-base)
                rot_params (optional): N X 6 global rotation (6D representation)
                neck_pose_params (optional): N X 6 rotation of the head around the neck joint
                jaw_pose_params (optional): N X 6 rotation of the jaw
                eye_pose_params (optional): N X 12 rotations of both eyeballs (6 values each)
                expression_params (optional): N X number of expression parameters
                eyelid_params (optional): N X 2, column 0 scales the left and
                    column 1 the right eyelid blendshape
                rot_params_lmk_shift (optional): global rotation used only for the
                    dynamic landmark selection; defaults to rot_params
                vertex_offsets: currently unused
            return:
                vertices: N X V X 3
                lmk68: N X number of landmarks X 3
                joint_transforms: second output of lbs for the full pose
                v_can: third output of lbs (taken from the no-neck call)
                vertices_noneck: N X V X 3 vertices posed with identity neck rotation
        """
        # All defaults are created on the device the model lives on instead of
        # a hard-coded .cuda(), so the module also runs on CPU.
        device = self.v_template.device

        if shape_params is None:
            # shape_params carries the batch size; recover it from any other
            # per-sample input before falling back to the registered default.
            provided = [p for p in (rot_params, neck_pose_params, jaw_pose_params, eye_pose_params,
                                    expression_params, eyelid_params, rot_params_lmk_shift)
                        if p is not None]
            batch_size = provided[0].shape[0] if provided else 1
            shape_params = self.shape_params.expand(batch_size, -1)
        else:
            batch_size = shape_params.shape[0]

        # Identity rotation in 6D form, one row per batch element.
        identity_6d = matrix_to_rotation_6d(
            torch.eye(3, dtype=self.dtype, device=device)[None].repeat(batch_size, 1, 1))

        if rot_params is None:
            rot_params = identity_6d.clone()
        if rot_params_lmk_shift is None:
            rot_params_lmk_shift = rot_params
        if neck_pose_params is None:
            neck_pose_params = identity_6d.clone()
        if jaw_pose_params is None:
            jaw_pose_params = identity_6d.clone()
        if eye_pose_params is None:
            eye_pose_params = torch.cat([identity_6d.clone()] * 2, dim=1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)

        # Concatenate identity shape and expression parameters
        betas = torch.cat([shape_params, expression_params], dim=1)

        # The pose vector contains global rotation, and neck, jaw, and eyeball rotations
        full_pose = torch.cat([rot_params, neck_pose_params, jaw_pose_params, eye_pose_params], dim=1)
        full_pose_no_neck = torch.cat([rot_params, identity_6d, jaw_pose_params, eye_pose_params], dim=1)
        full_pose_lmk_shift = torch.cat([rot_params_lmk_shift, neck_pose_params, jaw_pose_params, eye_pose_params], dim=1)

        # FLAME models shape and expression deformations as vertex offset from the mean face in 'zero pose', called v_template
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)

        # Use linear blendskinning to model pose rotations
        vertices, joint_transforms, v_can = lbs(betas, full_pose, template_vertices,
                                                self.shapedirs, self.posedirs,
                                                self.J_regressor, self.parents,
                                                self.lbs_weights, dtype=self.dtype)
        # Second pass with the neck rotation replaced by identity; note that
        # the returned v_can comes from this call.
        vertices_noneck, _, v_can = lbs(betas, full_pose_no_neck, template_vertices,
                                        self.shapedirs, self.posedirs,
                                        self.J_regressor, self.parents,
                                        self.lbs_weights, dtype=self.dtype)

        if eyelid_params is not None:
            vertices = vertices + self.r_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 1:2, None]
            vertices = vertices + self.l_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 0:1, None]

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1).contiguous()
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1).contiguous()

        dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
            vertices, full_pose_lmk_shift, self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.neck_kin_chain, cameras, dtype=self.dtype)
        # Dynamic (contour) landmarks come first, followed by the static ones.
        lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)

        lmk68 = self._vertices2landmarks(vertices, self.faces, lmk_faces_idx, lmk_bary_coords)

        # trans_params is intentionally not added to vertices / lmk68:
        # it is always zero in this code-base.

        return vertices, lmk68, joint_transforms, v_can, vertices_noneck

    def _register_default_params(self, param_fname, dim):
        """Register a frozen all-zero parameter of shape (1, dim) under ``param_fname``."""
        default_params = torch.zeros([1, dim], dtype=self.dtype, requires_grad=False)
        self.register_parameter(param_fname, nn.Parameter(default_params, requires_grad=False))


class FLAMETex(nn.Module):
    """Linear texture model: mean texture plus a weighted sum of basis textures."""

    def __init__(self, config, texture_mask_index, tex_res):
        """Load the texture space from ``config.tex_space_path``.

        ``config.tex_params`` selects how many basis components are kept,
        ``texture_mask_index`` is a pair of index arrays addressing the texels
        that receive ``tex_offsets`` in ``forward`` and ``tex_res`` is the
        side length of the output texture.
        """
        super(FLAMETex, self).__init__()
        tex_space = np.load(config.tex_space_path)
        # FLAME texture
        if 'tex_dir' in tex_space.files:
            mu_key = 'mean'
            pc_key = 'tex_dir'
            n_pc = 200
            scale = 1
        # BFM to FLAME texture
        else:
            mu_key = 'MU'
            pc_key = 'PC'
            n_pc = 199
            scale = 255.0
        texture_mean = tex_space[mu_key].reshape(1, -1)
        texture_basis = tex_space[pc_key].reshape(-1, n_pc)
        n_tex = config.tex_params
        texture_mean = torch.from_numpy(texture_mean).float()[None, ...] * scale
        texture_basis = torch.from_numpy(texture_basis[:, :n_tex]).float()[None, ...] * scale
        # Optional fixed texture; when set, forward() returns it instead of
        # evaluating the linear model (see check_texture).
        self.texture = None
        self.register_buffer('texture_mean', texture_mean)
        self.register_buffer('texture_basis', texture_basis)
        self.image_size = (512, 512)
        self.texture_mask_index = texture_mask_index
        self.tex_res = tex_res

    def check_texture(self, config):
        """Load ``<config.actor>/texture.png`` as the fixed texture if the file exists."""
        path = os.path.join(config.actor, 'texture.png')
        if os.path.exists(path):
            self.texture = torch.from_numpy(imread(path)).permute(2, 0, 1).cuda()[None, 0:3, :, :] / 255.0

    def forward(self, texcode, tex_offsets=None):
        """Return the texture for ``texcode`` as an N x 3 x tex_res x tex_res tensor.

        If a fixed texture was loaded it is returned instead, resized to
        ``self.image_size``. ``tex_offsets`` (optional) is added to the texels
        selected by ``self.texture_mask_index``.
        """
        if self.texture is not None:
            return F.interpolate(self.texture, self.image_size, mode='bilinear')

        texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
        # The texture space is stored as flattened 512 x 512 x 3 images.
        texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
        texture = F.interpolate(texture, (self.tex_res, self.tex_res), mode='bilinear')
        # Reverse the channel order and rescale from [0, 255] to [0, 1].
        texture = texture[:, [2, 1, 0], :, :]
        texture = texture / 255.
        if tex_offsets is not None:
            texture[:, :, self.texture_mask_index[0], self.texture_mask_index[1]] += tex_offsets
        return texture