# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2022 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: mica@tue.mpg.de


from abc import abstractmethod

import numpy as np
import torch
import torch.nn as nn

from .config import cfg
from .renderer import MeshShapeRenderer
from .flame import FLAME
from .masking import Masking


class BaseModel(nn.Module):
    """Common scaffolding shared by the concrete models.

    The constructor only stores configuration and builds the masking helper;
    the heavier setup (FLAME, network, checkpoint, renderer, mask weights) is
    done by :meth:`initialize`. Subclasses provide the hooks declared with
    ``@abstractmethod`` below.

    NOTE(review): ``@abstractmethod`` is only enforced when the class uses an
    ``ABCMeta``-derived metaclass; confirm whether that is the case here
    before relying on instantiation-time checks.
    """

    def __init__(self, config=None, device=None, tag=''):
        """Store configuration and create the masking helper.

        Args:
            config: Configuration object. When ``None`` the module-level
                ``cfg`` imported from ``.config`` is used instead.
            device: Device handed to ``.to(...)`` / ``device=`` wherever this
                class moves or creates tensors.
            tag: Free-form label stored on the instance as ``self.tag``.
        """
        super().__init__()
        self.cfg = cfg if config is None else config

        self.tag = tag
        self.use_mask = self.cfg.train.use_mask
        self.device = device
        # Use the resolved configuration, not the raw argument: the raw
        # ``config`` is ``None`` whenever the default ``cfg`` fallback is
        # taken, and Masking would otherwise be built without a config.
        self.masking = Masking(self.cfg)
        self.testing = self.cfg.model.testing

    def initialize(self):
        """Run the full setup sequence.

        Order: FLAME, model, checkpoint loading, renderer, mask weights.
        ``setup_renderer`` reads ``self.flame``, so ``create_flame`` has to
        run before it.

        NOTE(review): ``load_model`` is not declared in this class (only
        ``create_load`` is), and ``create_model`` is called with one argument
        although its stub below takes none. Subclasses presumably define both
        with matching signatures -- confirm.
        """
        model_cfg = self.cfg.model
        self.create_flame(model_cfg)
        self.create_model(model_cfg)
        self.load_model()
        self.setup_renderer(model_cfg)
        self.create_weights()

    def create_flame(self, model_cfg):
        """Build the FLAME module, cache its template and switch it to eval mode.

        Args:
            model_cfg: Model section of the configuration, passed to ``FLAME``.
        """
        self.flame = FLAME(model_cfg).to(self.device)
        # Cloned copy of the template with a leading dimension added.
        self.average_face = self.flame.v_template.clone()[None]

        self.flame.eval()

    @abstractmethod
    def create_model(self):
        """Hook for subclasses; the base implementation returns ``None``."""
        return

    @abstractmethod
    def create_load(self):
        """Hook for subclasses; the base implementation returns ``None``."""
        return

    @abstractmethod
    def model_dict(self):
        """Hook for subclasses; the base implementation returns ``None``."""
        return

    @abstractmethod
    def parameters_to_optimize(self):
        """Hook for subclasses; the base implementation returns ``None``."""
        return

    @abstractmethod
    def encode(self, images, arcface_images):
        """Hook for subclasses; the base implementation returns ``None``."""
        return

    @abstractmethod
    def decode(self, codedict, epoch):
        """Hook for subclasses; the base implementation returns ``None``."""
        pass

    @abstractmethod
    def compute_losses(self, input, encoder_output, decoder_output):
        """Hook for subclasses; the base implementation returns ``None``."""
        pass

    @abstractmethod
    def compute_masks(self, input, decoder_output):
        """Hook for subclasses; the base implementation returns ``None``."""
        pass

    def setup_renderer(self, model_cfg):
        """Create the mesh renderer and reset the cached template meshes.

        Args:
            model_cfg: Model section of the configuration; its
                ``topology_path`` is passed to the renderer as
                ``obj_filename``.
        """
        self.render = MeshShapeRenderer(obj_filename=model_cfg.topology_path)

        self.verts_template_neutral = self.flame.v_template[None]
        # Filled lazily by create_template().
        self.verts_template = None
        self.verts_template_uv = None

    def create_weights(self):
        """Fetch per-vertex and per-triangle mask weights onto ``self.device``."""
        self.vertices_mask = self.masking.get_weights_per_vertex().to(self.device)
        self.triangle_mask = self.masking.get_weights_per_triangle().to(self.device)

    def create_template(self, B):
        """Ensure the cached template meshes exist and match batch size ``B``.

        On the first call ``self.verts_template`` is produced by running FLAME
        with all-zero shape and expression parameters and a pose that is zero
        except for component 3. On later calls with a different ``B`` both
        cached templates are rebuilt by repeating their first entry ``B``
        times. Nothing is returned; results are stored on ``self``.

        Args:
            B: Requested batch size (size of the leading dimension).
        """
        with torch.no_grad():
            if self.verts_template is None:
                model_cfg = self.cfg.model
                self.verts_template_neutral = self.flame.v_template[None]

                pose = torch.zeros(B, model_cfg.n_pose, device=self.device)
                # 10 degrees expressed in radians (original note here: "48").
                pose[:, 3] = 10.0 * np.pi / 180.0

                zero_shape = torch.zeros(B, model_cfg.n_shape, device=self.device)
                zero_exp = torch.zeros(B, model_cfg.n_exp, device=self.device)
                # use template mesh with open mouth
                self.verts_template, _, _ = self.flame(
                    shape_params=zero_shape,
                    expression_params=zero_exp,
                    pose_params=pose,
                )

            # NOTE(review): right after the first call verts_template already
            # has B entries, so this branch is skipped and
            # verts_template_neutral keeps its single leading entry until a
            # later call with a different B -- confirm callers are fine
            # with that.
            if self.verts_template.shape[0] != B:
                self.verts_template_neutral = self.verts_template_neutral[0:1].repeat(B, 1, 1)
                self.verts_template = self.verts_template[0:1].repeat(B, 1, 1)