Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2023 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: mica@tue.mpg.de | |
| import os | |
| import pickle | |
| from pixel3dmm import env_paths | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from trimesh import Trimesh | |
def to_tensor(array, dtype=torch.float32):
    """Convert ``array`` to a new ``torch.Tensor`` of the given ``dtype``.

    Args:
        array: Any array-like accepted by ``torch.tensor`` (list, NumPy array,
            scalar, ...) or an existing ``torch.Tensor``.
        dtype: Target torch dtype.

    Returns:
        A freshly allocated tensor (never a view of ``array``) that is detached
        from any autograd graph.
    """
    # The previous string check looked for 'torch.tensor', but the type string
    # is 'torch.Tensor', so it never matched. Tensors then went through
    # torch.tensor(tensor), which warns. Do the documented equivalent explicitly.
    if isinstance(array, torch.Tensor):
        return array.detach().clone().to(dtype=dtype)
    return torch.tensor(array, dtype=dtype)
def to_np(array, dtype=np.float32):
    """Return ``array`` as a dense NumPy array of ``dtype``.

    SciPy sparse inputs (recognized by their type name) are densified with
    ``.todense()`` before conversion.
    """
    looks_sparse = 'scipy.sparse' in str(type(array))
    dense = array.todense() if looks_sparse else array
    return np.array(dense, dtype=dtype)
class Struct(object):
    """Minimal record type: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
class Masking(nn.Module):
    """Vertex/triangle region masks and per-element loss weights for FLAME.

    Loads the FLAME region masks (``FLAME_masks.pkl``) and the FLAME model
    asset. Registers the template faces (long) and template vertices (float32)
    as buffers. Builds a vertex adjacency map that ``make_soft`` uses to grow
    masks ring by ring.
    """

    def __init__(self, config):
        super(Masking, self).__init__()
        ROOT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
        # NOTE: pickle.load can execute arbitrary code. These are expected to be
        # trusted local assets; never point these paths at untrusted files.
        with open(f'{ROOT_DIR}/data/FLAME2020/FLAME_masks/FLAME_masks.pkl', 'rb') as f:
            ss = pickle.load(f, encoding='latin1')
        self.masks = Struct(**ss)
        with open(f'{env_paths.FLAME_ASSET}', 'rb') as f:
            ss = pickle.load(f, encoding='latin1')
        flame_model = Struct(**ss)
        self.masked_faces = None  # lazily filled cache for get_masked_faces()
        self.cfg = config.mask_weights
        self.dtype = torch.float32
        self.register_buffer('faces', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
        self.register_buffer('vertices', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
        self.neighbours = self._build_neighbours(self.faces.numpy())

    @staticmethod
    def _build_neighbours(faces):
        """Map ``str(vertex_id)`` to the set of vertex ids sharing a face with it."""
        neighbours = {}
        for face in faces:
            for v in face:
                # Keys are strings to match how make_soft looks them up.
                neighbours.setdefault(str(v), set()).update(a for a in face if a != v)
        return neighbours

    def get_faces(self):
        """Return the registered template face buffer."""
        return self.faces

    def get_mask_face(self):
        """Return the vertex ids of the 'face' region mask."""
        return self.masks.face

    def get_mask_eyes(self):
        """Return the sorted unique vertex ids of both eyeballs."""
        left = self.masks.left_eyeball
        right = self.masks.right_eyeball
        return np.unique(np.concatenate((left, right)))

    def get_mask_forehead(self):
        """Return the vertex ids of the 'forehead' region mask."""
        return self.masks.forehead

    def get_mask_lips(self):
        """Return the vertex ids of the 'lips' region mask."""
        return self.masks.lips

    def get_mask_eye_region(self):
        """Return the vertex ids of the 'eye_region' mask."""
        return self.masks.eye_region

    def get_mask_lr_eye_region(self):
        """Return the sorted unique ids of both eye regions plus both eyeballs."""
        left = self.masks.left_eye_region
        right = self.masks.right_eye_region
        return np.unique(np.concatenate((left, right, self.get_mask_eyes())))

    def get_mask_nose(self):
        """Return the vertex ids of the 'nose' region mask."""
        return self.masks.nose

    def get_mask_ears(self):
        """Return the sorted unique vertex ids of both ears."""
        left = self.masks.left_ear
        right = self.masks.right_ear
        return np.unique(np.concatenate((left, right)))

    def get_triangle_face_mask(self):
        """Return indices of faces lying entirely inside the 'face' region."""
        m = self.masks.face
        return self.get_triangle_mask(m)

    def get_triangle_eyes_mask(self):
        """Return indices of faces lying entirely inside the eyeball regions."""
        m = self.get_mask_eyes()
        return self.get_triangle_mask(m)

    def get_triangle_whole_mask(self):
        """Return indices of faces lying entirely inside ``get_whole_mask()``."""
        # NOTE(review): get_whole_mask is not defined in this class as visible
        # here; unless it is provided elsewhere, this raises AttributeError.
        m = self.get_whole_mask()
        return self.get_triangle_mask(m)

    def _faces_inside(self, vertex_ids):
        """Return a per-face bool array: True where all 3 vertices are in ``vertex_ids``."""
        if isinstance(vertex_ids, (set, frozenset)):
            # np.isin does not understand sets; materialize them first.
            vertex_ids = list(vertex_ids)
        faces = self.faces.cpu().numpy()
        return np.isin(faces, np.asarray(vertex_ids)).all(axis=1)

    def get_triangle_mask(self, m):
        """Return sorted indices of faces whose three vertices all lie in ``m``.

        Args:
            m: Collection of vertex ids (array, list or set).

        Returns:
            An integer NumPy index array. It is an integer array even when empty,
            so it can index tensors directly.
        """
        return np.flatnonzero(self._faces_inside(m))

    def make_soft(self, mask, value, degree=4):
        """Grow ``mask`` outwards by ``degree`` rings of neighbouring vertices.

        Args:
            mask: Iterable of seed vertex ids.
            value: Base weight; ring ``r`` (0-based) receives ``value / (r + 2)``.
            degree: Number of rings to add.

        Returns:
            A list of ``(ring_vertex_ids, weight)`` tuples, one per ring. Each ring
            holds only vertices not already in the seed mask or in earlier rings.
        """
        soft = []
        mask = set(mask)
        for ring in range(degree):
            soft_ring = []
            for v in mask.copy():
                for n in self.neighbours[str(v)]:
                    if n in mask:
                        continue
                    soft_ring.append(n)
                    mask.add(n)
            soft.append((soft_ring, value / (ring + 2)))
        return soft

    def get_binary_triangle_mask(self):
        """Return a per-face list of bools: True where the face lies fully in ``get_whole_mask()``."""
        # NOTE(review): relies on get_whole_mask, which is not defined in this
        # class as visible here.
        return self._faces_inside(self.get_whole_mask()).tolist()

    def get_masked_faces(self):
        """Return (and cache) the template faces restricted to ``get_binary_triangle_mask()``.

        The result is a ``(1, F', 3)`` long tensor on CUDA. It is computed once
        and then reused from ``self.masked_faces``.
        """
        if self.masked_faces is None:
            faces = self.faces.cpu().numpy()
            vertices = self.vertices.cpu().numpy()
            m = Trimesh(vertices=vertices, faces=faces, process=False)
            m.update_faces(self.get_binary_triangle_mask())
            self.masked_faces = torch.from_numpy(np.array(m.faces)).cuda().long()[None]
        return self.masked_faces

    def get_weights_per_triangle(self):
        """Return per-face weights of shape ``(1, F, 1)`` from ``self.cfg``.

        Faces default to ``cfg.whole``. Eyeball faces are set to ``cfg.eyes``,
        then face-region faces to ``cfg.face``; the later assignment wins.
        """
        mask = torch.ones_like(self.get_faces()[None]).detach() * self.cfg.whole
        mask[:, self.get_triangle_eyes_mask(), :] = self.cfg.eyes
        mask[:, self.get_triangle_face_mask(), :] = self.cfg.face
        return mask[:, :, 0:1]

    def get_weights_per_vertex(self):
        """Return per-vertex weights shaped like ``self.vertices[None]``.

        Vertices default to ``cfg.whole``, then eyes, ears and face regions are
        overwritten in that order, so later regions win on overlap.
        """
        mask = torch.ones_like(self.vertices[None]).detach() * self.cfg.whole
        mask[:, self.get_mask_eyes(), :] = self.cfg.eyes
        mask[:, self.get_mask_ears(), :] = self.cfg.ears
        mask[:, self.get_mask_face(), :] = self.cfg.face
        return mask

    def get_masked_mesh(self, vertices, triangle_mask):
        """Cut each mesh in a batch down to the faces selected by ``triangle_mask``.

        Args:
            vertices: Vertex positions of shape ``(N, 3)`` or ``(B, N, 3)``. A
                2-D input is treated as a batch of one.
            triangle_mask: Per-face keep mask passed to ``Trimesh.update_faces``.

        Returns:
            ``(masked_vertices, masked_faces)``, each stacked along a leading batch
            dimension and placed on CUDA. Faces are returned as float32 to match
            this method's historical output.

        Raises:
            ValueError: if ``vertices`` is not 2-D or 3-D. Also raised if batch
                elements produce different vertex or face counts after
                processing; earlier versions silently dropped the preceding
                elements in that case.
        """
        if len(vertices.shape) == 2:
            vertices = vertices[None]
        if len(vertices.shape) != 3:
            raise ValueError(
                f'expected vertices of shape (N, 3) or (B, N, 3), got {tuple(vertices.shape)}')
        faces = self.faces.cpu().numpy()
        per_item_vertices = []
        per_item_faces = []
        for item_vertices in vertices:
            m = Trimesh(vertices=item_vertices.detach().cpu().numpy(), faces=faces, process=False)
            m.update_faces(triangle_mask)
            # process() merges duplicate vertices and re-indexes faces.
            m.process()
            per_item_vertices.append(torch.from_numpy(np.array(m.vertices)).float())
            per_item_faces.append(torch.from_numpy(np.array(m.faces)).float())
        if not per_item_vertices:
            return torch.empty(0, 0, 3).cuda(), torch.empty(0, 0, 3).cuda()
        if (len({v.shape for v in per_item_vertices}) > 1
                or len({f.shape for f in per_item_faces}) > 1):
            raise ValueError('masked meshes differ in vertex/face count across the batch')
        return torch.stack(per_item_vertices).cuda(), torch.stack(per_item_faces).cuda()