alexnasa's picture
variable name fixed
5cf3b22
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
import pickle
import loguru
import numpy as np
import torch
import torch.nn as nn
from pixel3dmm import env_paths
from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler
def to_tensor(array, dtype=torch.float32):
if 'torch.tensor' not in str(type(array)):
return torch.tensor(array, dtype=dtype)
def to_np(array, dtype=np.float32):
if 'scipy.sparse' in str(type(array)):
array = array.todense()
return np.array(array, dtype=dtype)
class Struct(object):
def __init__(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
class FLAME(nn.Module):
"""
borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
Given flame parameters this class generates a differentiable FLAME function
which outputs the a mesh and 2D/3D facial landmarks
"""
def __init__(self, config, optimize_basis=False):
super(FLAME, self).__init__()
loguru.logger.info("[FLAME] creating the FLAME Decoder")
with open(f'{env_paths.FLAME_ASSET}', 'rb') as f:
ss = pickle.load(f, encoding='latin1')
flame_model = Struct(**ss)
self.optimize_basis = optimize_basis
self.cfg = config
self.dtype = torch.float32
self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
# The vertices of the template model
self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
self.n_vertices = self.v_template.shape[0]
# The shape components and expression
shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
shapedirs = torch.cat([shapedirs[:, :, :config.n_shape], shapedirs[:, :, 300:]], 2)
if optimize_basis:
self.register_parameter('shapedirs', torch.nn.Parameter(shapedirs))
else:
self.register_buffer('shapedirs', shapedirs)
self.n_shape = config.n_shape
# The pose components
num_pose_basis = flame_model.posedirs.shape[-1]
posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
#
self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
parents = to_tensor(to_np(flame_model.kintree_table[0])).long();
parents[0] = -1
self.register_buffer('parents', parents)
self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))
# Fixing Eyeball and neck rotation
default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose, requires_grad=False))
default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
self.register_parameter('neck_pose', nn.Parameter(default_neck_pose, requires_grad=False))
# Static and Dynamic Landmark embeddings for FLAME
lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1')
lmk_embeddings = lmk_embeddings[()]
self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long())
self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long())
self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype))
self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long())
self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype))
neck_kin_chain = [];
NECK_IDX = 1
curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
while curr_idx != -1:
neck_kin_chain.append(curr_idx)
curr_idx = self.parents[curr_idx]
self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))
def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
dynamic_lmk_b_coords,
neck_kin_chain, dtype=torch.float32):
"""
Selects the face contour depending on the reletive position of the head
Input:
vertices: N X num_of_vertices X 3
pose: N X full pose
dynamic_lmk_faces_idx: The list of contour face indexes
dynamic_lmk_b_coords: The list of contour barycentric weights
neck_kin_chain: The tree to consider for the relative rotation
dtype: Data type
return:
The contour face indexes and the corresponding barycentric weights
"""
batch_size = pose.shape[0]
aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
neck_kin_chain)
rot_mats = batch_rodrigues(
aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
rel_rot_mat = torch.eye(3, device=pose.device,
dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
for idx in range(len(neck_kin_chain)):
rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
y_rot_angle = torch.round(
torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
max=39)).to(dtype=torch.long)
neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
mask = y_rot_angle.lt(-39).to(dtype=torch.long)
neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
y_rot_angle = (neg_mask * neg_vals +
(1 - neg_mask) * y_rot_angle)
dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
0, y_rot_angle)
dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
0, y_rot_angle)
return dyn_lmk_faces_idx, dyn_lmk_b_coords
def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
"""
Calculates landmarks by barycentric interpolation
Input:
vertices: torch.tensor NxVx3, dtype = torch.float32
The tensor of input vertices
faces: torch.tensor (N*F)x3, dtype = torch.long
The faces of the mesh
lmk_faces_idx: torch.tensor N X L, dtype = torch.long
The tensor with the indices of the faces used to calculate the
landmarks.
lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
The tensor of barycentric coordinates that are used to interpolate
the landmarks
Returns:
landmarks: torch.tensor NxLx3, dtype = torch.float32
The coordinates of the landmarks for each mesh in the batch
"""
# Extract the indices of the vertices for each face
# NxLx3
batch_size, num_verts = vertices.shape[:dd2]
lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1)
lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to(
device=vertices.device) * num_verts
lmk_vertices = vertices.view(-1, 3)[lmk_faces]
landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
return landmarks
# def seletec_3d68(self, vertices):
def compute_landmarks(self, vertices):
landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
return landmarks3d
def seletec_3d68(self, vertices):
landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
return landmarks3d
def project_to_shape_basis(self, shape_vector, shape_as_offset=False):
batch_size = shape_vector.shape[0]
n_vertices = self.v_template.shape[0]
n_eigenvectors = self.n_shape
# shape_params = basis dot (shape_vector - average) # uses properties of the PCA
if shape_as_offset:
diff = shape_vector
else:
diff = shape_vector - self.v_template
return torch.matmul(diff.reshape(batch_size, -1), self.shapedirs[:, :, :n_eigenvectors].reshape(3 * n_vertices, n_eigenvectors))
def compute_distance_to_basis(self, shape_vector, shape_as_offset=False):
batch_size = shape_vector.shape[0]
n_vertices = self.v_template.shape[0]
n_eigenvectors = self.n_shape
# shape_vector torch.Size([3, 5023, 3])
# self.v_template torch.Size([5023, 3])
# self.shapedirs torch.Size([5023, 3, 150])
# diff torch.Size([3, 5023, 3])
# shape_params torch.Size([5023, 15069])
# shape_params = basis dot (shape_vector - average) # uses properties of the PCA
if shape_as_offset:
diff = shape_vector
else:
diff = shape_vector - self.v_template
shape_params = torch.matmul(diff.reshape(batch_size, -1), self.shapedirs[:, :, :n_eigenvectors].reshape(3 * n_vertices, n_eigenvectors))
distance = diff - torch.matmul(shape_params, self.shapedirs[:, :, :n_eigenvectors].reshape(n_vertices * 3, n_eigenvectors).t()).reshape(batch_size, n_vertices, 3)
return distance
def get_std(self):
n_eigenvectors = self.cfg.n_shape
basis = self.shapedirs[:, :, :n_eigenvectors]
std = torch.norm(basis.reshape(-1, n_eigenvectors), dim=0)
return std
def compute_closest_shape(self, shape_vector):
B = shape_vector.shape[0]
N = self.v_template.shape[0]
n_eigenvectors = self.cfg.n_shape
basis = self.shapedirs[:, :, :n_eigenvectors]
diff = (shape_vector - self.v_template).reshape(B, -1)
std = torch.norm(basis.reshape(-1, n_eigenvectors), dim=0)
inv = 1.0 / std.square()
params = inv * torch.matmul(diff, basis.reshape(3 * N, n_eigenvectors))
# params = torch.max(torch.min(params, std*-3.0), std*3.0)
return self.v_template + torch.matmul(params, basis.reshape(N * 3, n_eigenvectors).T).reshape(B, N, 3), params
def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None, neck_pose_params=None, shape_basis_delta=None):
"""
Input:
shape_params: N X number of shape parameters
expression_params: N X number of expression parameters
pose_params: N X number of pose parameters (6)
return:d
vertices: N X V X 3
landmarks: N X number of landmarks X 3
"""
batch_size = shape_params.shape[0]
if pose_params is None:
pose_params = self.eye_pose.expand(batch_size, -1)
if eye_pose_params is None:
eye_pose_params = self.eye_pose.expand(batch_size, -1)
if neck_pose_params is None:
neck_pose_params = self.neck_pose.expand(batch_size, -1)
if expression_params is None:
expression_params = torch.zeros([1, 100], dtype=self.dtype, requires_grad=False, device=self.neck_pose.device).expand(batch_size, -1)
betas = torch.cat([shape_params, expression_params], dim=1)
full_pose = torch.cat([pose_params[:, :3], neck_pose_params, pose_params[:, 3:], eye_pose_params], dim=1)
template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
vertices, _ = lbs(betas, full_pose, template_vertices,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
self.lbs_weights, dtype=self.dtype)
lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
full_pose, self.dynamic_lmk_faces_idx,
self.dynamic_lmk_bary_coords,
self.neck_kin_chain, dtype=self.dtype)
lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
lmk_faces_idx,
lmk_bary_coords)
bz = vertices.shape[0]
landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
self.full_lmk_faces_idx.repeat(bz, 1),
self.full_lmk_bary_coords.repeat(bz, 1, 1))
return vertices, landmarks2d, landmarks3d