# Provenance (web-page residue, kept as a comment so the module parses):
# last commit 1dbf56f by alexnasa - "variable name fixed!"
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: mica@tue.mpg.de
import os
import pickle
from pixel3dmm import env_paths
import numpy as np
import torch
import torch.nn as nn
from trimesh import Trimesh
def to_tensor(array, dtype=torch.float32):
    """Return a new torch tensor holding *array* converted to ``dtype``.

    Tensor inputs are copied with ``detach().clone()`` (the form PyTorch
    recommends instead of ``torch.tensor(tensor)``), so the result never
    shares storage or autograd history with the input and stays on the
    input's device. Any other input (lists, numpy arrays, scalars, ...) is
    passed to ``torch.tensor``.
    """
    # Explicit isinstance check: every input yields a tensor (no implicit None path).
    if isinstance(array, torch.Tensor):
        return array.detach().clone().to(dtype)
    return torch.tensor(array, dtype=dtype)
def to_np(array, dtype=np.float32):
    """Return *array* as a numpy ndarray of the requested ``dtype``.

    scipy sparse matrices are densified before conversion; everything else is
    handed to ``np.array`` unchanged.
    """
    is_sparse = 'scipy.sparse' in str(type(array))
    source = array.todense() if is_sparse else array
    return np.array(source, dtype=dtype)
class Struct(object):
    """Simple attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
class Masking(nn.Module):
    """Region masks for the FLAME mesh and helpers built on top of them.

    Region masks are loaded from ``FLAME_masks.pkl`` and are collections of
    vertex indices per region; the mesh topology (``faces``) and template
    vertices come from the pickle at ``env_paths.FLAME_ASSET``.
    ``config.mask_weights`` supplies the per-region weights (``whole``,
    ``eyes``, ``ears``, ``face``) read by the ``get_weights_per_*`` methods.
    """

    def __init__(self, config):
        super(Masking, self).__init__()
        ROOT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
        # NOTE: pickle.load can execute arbitrary code; both files must be trusted local assets.
        with open(f'{ROOT_DIR}/data/FLAME2020/FLAME_masks/FLAME_masks.pkl', 'rb') as f:
            ss = pickle.load(f, encoding='latin1')
        self.masks = Struct(**ss)

        with open(f'{env_paths.FLAME_ASSET}', 'rb') as f:
            ss = pickle.load(f, encoding='latin1')
        flame_model = Struct(**ss)

        self.masked_faces = None  # lazily filled cache for get_masked_faces()
        self.cfg = config.mask_weights
        self.dtype = torch.float32
        self.register_buffer('faces', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
        self.register_buffer('vertices', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))

        # Vertex adjacency keyed by str(vertex index): all vertices sharing a face with it.
        self.neighbours = {}
        for face in self.faces.numpy():
            for v in face:
                adjacent = self.neighbours.setdefault(str(v), set())
                adjacent.update(a for a in face if a != v)

    @staticmethod
    def _as_index_array(indices):
        """Return *indices* (array, list, set or tensor) as a flat numpy array for np.isin."""
        # np.isin treats a Python set as a single object element rather than
        # a collection of values, so sets must be materialised first.
        if isinstance(indices, (set, frozenset)):
            indices = list(indices)
        elif isinstance(indices, torch.Tensor):
            indices = indices.detach().cpu().numpy()
        return np.asarray(indices).ravel()

    def get_faces(self):
        """Return the ``faces`` buffer (long tensor of triangle vertex indices)."""
        return self.faces

    def get_mask_face(self):
        """Return the vertex indices of the face region."""
        return self.masks.face

    def get_mask_eyes(self):
        """Return the sorted union of left and right eyeball vertex indices."""
        left = self.masks.left_eyeball
        right = self.masks.right_eyeball
        return np.unique(np.concatenate((left, right)))

    def get_mask_forehead(self):
        """Return the vertex indices of the forehead region."""
        return self.masks.forehead

    def get_mask_lips(self):
        """Return the vertex indices of the lips region."""
        return self.masks.lips

    def get_mask_eye_region(self):
        """Return the vertex indices of the combined eye region."""
        return self.masks.eye_region

    def get_mask_lr_eye_region(self):
        """Return the sorted union of left/right eye regions plus both eyeballs."""
        left = self.masks.left_eye_region
        right = self.masks.right_eye_region
        return np.unique(np.concatenate((left, right, self.get_mask_eyes())))

    def get_mask_nose(self):
        """Return the vertex indices of the nose region."""
        return self.masks.nose

    def get_mask_ears(self):
        """Return the sorted union of left and right ear vertex indices."""
        left = self.masks.left_ear
        right = self.masks.right_ear
        return np.unique(np.concatenate((left, right)))

    def get_triangle_face_mask(self):
        """Return indices of faces whose vertices all lie in the face region."""
        m = self.masks.face
        return self.get_triangle_mask(m)

    def get_triangle_eyes_mask(self):
        """Return indices of faces whose vertices all lie in the eyeball regions."""
        m = self.get_mask_eyes()
        return self.get_triangle_mask(m)

    def get_triangle_whole_mask(self):
        """Return indices of faces whose vertices all lie in ``get_whole_mask()``."""
        # NOTE(review): get_whole_mask is not defined in this class as visible
        # here; this raises AttributeError unless it is provided elsewhere.
        m = self.get_whole_mask()
        return self.get_triangle_mask(m)

    def get_triangle_mask(self, m):
        """Return the sorted indices of faces whose vertices are all in *m*.

        Args:
            m: collection of vertex indices (numpy array, list, set or tensor).

        Returns:
            Sorted integer numpy array of face indices (empty, but still
            integer-typed, when no face qualifies, so it is safe to use as a
            tensor index).
        """
        faces = self.faces.cpu().numpy()
        members = self._as_index_array(m)
        fully_inside = np.isin(faces, members).all(axis=1)
        return np.flatnonzero(fully_inside)

    def make_soft(self, mask, value, degree=4):
        """Grow *mask* outward by ``degree`` rings of neighbouring vertices.

        Returns:
            A list with one ``(ring_vertices, value / (ring + 2))`` tuple per
            ring, where ``ring_vertices`` are the vertices newly reached in that
            ring (ring numbering starts at 0, so weights are value/2, value/3, ...).
        """
        soft = []
        mask = set(mask)
        for ring in range(degree):
            soft_ring = []
            # Iterate a snapshot: vertices added in this ring are expanded in the next one.
            for v in mask.copy():
                for n in self.neighbours[str(v)]:
                    if n in mask:
                        continue
                    soft_ring.append(n)
                    mask.add(n)
            soft.append((soft_ring, value / (ring + 2)))
        return soft

    def get_binary_triangle_mask(self):
        """Return a per-face list of bools: True when all its vertices are in ``get_whole_mask()``."""
        # NOTE(review): get_whole_mask is not defined in this class as visible
        # here; this raises AttributeError unless it is provided elsewhere.
        mask = self._as_index_array(self.get_whole_mask())
        faces = self.faces.cpu().numpy()
        return np.isin(faces, mask).all(axis=1).tolist()

    def get_masked_faces(self):
        """Return (and cache) the faces kept by ``get_binary_triangle_mask``.

        The result is a long CUDA tensor with a leading batch dimension of 1,
        re-indexed by trimesh against the template mesh.
        """
        if self.masked_faces is None:
            faces = self.faces.cpu().numpy()
            vertices = self.vertices.cpu().numpy()
            m = Trimesh(vertices=vertices, faces=faces, process=False)
            m.update_faces(self.get_binary_triangle_mask())
            self.masked_faces = torch.from_numpy(np.array(m.faces)).cuda().long()[None]
        return self.masked_faces

    def get_weights_per_triangle(self):
        """Return per-face weights of shape (1, num_faces, 1).

        Every face starts at ``cfg.whole``; eyeball faces are set to ``cfg.eyes``
        and face-region faces to ``cfg.face`` (applied last, so it wins on overlap).
        """
        # Build in self.dtype: the faces buffer is long, and inheriting its dtype
        # would truncate fractional region weights whenever cfg.whole is an int.
        mask = torch.ones_like(self.get_faces()[None], dtype=self.dtype).detach() * self.cfg.whole
        mask[:, self.get_triangle_eyes_mask(), :] = self.cfg.eyes
        mask[:, self.get_triangle_face_mask(), :] = self.cfg.face
        return mask[:, :, 0:1]

    def get_weights_per_vertex(self):
        """Return per-vertex weights shaped like ``vertices[None]``.

        Every vertex starts at ``cfg.whole``; eyes, ears and face vertices are
        overwritten in that order (face last).
        """
        mask = torch.ones_like(self.vertices[None]).detach() * self.cfg.whole
        mask[:, self.get_mask_eyes(), :] = self.cfg.eyes
        mask[:, self.get_mask_ears(), :] = self.cfg.ears
        mask[:, self.get_mask_face(), :] = self.cfg.face
        return mask

    def get_masked_mesh(self, vertices, triangle_mask):
        """Cut each mesh in the batch down to the faces selected by *triangle_mask*.

        Args:
            vertices: 2-D (treated as batch size 1) or 3-D batched vertex tensor.
            triangle_mask: per-face selector passed to ``Trimesh.update_faces``.

        Returns:
            ``(masked_vertices, masked_faces)`` CUDA tensors; vertices are
            float32 and faces are int64 (matching ``get_masked_faces``).
        """
        if vertices.dim() == 2:
            vertices = vertices[None]
        B, _, _ = vertices.shape  # also enforces a 3-D input
        faces = self.faces.cpu().numpy()
        masked_vertices = torch.empty(0, 0, 3).cuda()
        masked_faces = torch.empty(0, 0, 3, dtype=torch.long).cuda()
        for i in range(B):
            m = Trimesh(vertices=vertices[i].detach().cpu().numpy(), faces=faces, process=False)
            m.update_faces(triangle_mask)
            # trimesh cleanup; it may change the vertex count per batch item.
            m.process()
            f = torch.from_numpy(np.array(m.faces)).cuda().long()[None]
            v = torch.from_numpy(np.array(m.vertices)).cuda()[None].float()
            # NOTE(review): a size mismatch with the previous item resets the
            # accumulator, silently dropping earlier batch items.
            if masked_vertices.shape[1] != v.shape[1]:
                masked_vertices = torch.empty(0, v.shape[1], 3).cuda()
            if masked_faces.shape[1] != f.shape[1]:
                masked_faces = torch.empty(0, f.shape[1], 3, dtype=torch.long).cuda()
            masked_vertices = torch.cat([masked_vertices, v])
            masked_faces = torch.cat([masked_faces, f])
        return masked_vertices, masked_faces