| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from abc import abstractmethod |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .config import cfg |
| | from .renderer import MeshShapeRenderer |
| | from .flame import FLAME |
| | from .masking import Masking |
| |
|
| |
|
class BaseModel(nn.Module):
    """Shared scaffolding for models built on top of a FLAME head model.

    The class wires together a FLAME instance, a mesh renderer and the
    per-vertex / per-triangle masks, and declares the hooks a concrete model
    has to provide (``create_model``, ``load_model``, ``encode``, ...).

    Note: the hooks are tagged with ``@abstractmethod`` for documentation
    only. The class does not use ``ABCMeta``, so Python does not prevent
    instantiation when a hook is left unimplemented.
    """

    def __init__(self, config=None, device=None, tag=''):
        """Store the configuration and build the masking helper.

        Args:
            config: Configuration object exposing ``train.use_mask`` and
                ``model.testing``. Falls back to the module-level ``cfg``
                when ``None``.
            device: Device that tensors and sub-modules are moved to.
            tag: Free-form label stored on the instance.
        """
        super().__init__()
        self.cfg = cfg if config is None else config

        self.tag = tag
        self.use_mask = self.cfg.train.use_mask
        self.device = device
        # Build the masking helper from the *resolved* configuration. Passing
        # ``config`` through unchanged handed ``None`` to Masking whenever the
        # default configuration was in use.
        self.masking = Masking(self.cfg)
        self.testing = self.cfg.model.testing

    def initialize(self):
        """Build every component, in dependency order.

        FLAME comes first because ``setup_renderer`` reads its template; the
        model is created before ``load_model`` is invoked on it.
        """
        model_cfg = self.cfg.model

        self.create_flame(model_cfg)
        self.create_model(model_cfg)
        self.load_model()
        self.setup_renderer(model_cfg)

        self.create_weights()

    def create_flame(self, model_cfg):
        """Instantiate FLAME on ``self.device`` and cache its template mesh.

        Args:
            model_cfg: Model section of the configuration, forwarded to FLAME.
        """
        self.flame = FLAME(model_cfg).to(self.device)
        # Cloned copy of the template with a leading axis of size one.
        self.average_face = self.flame.v_template.clone()[None]

        self.flame.eval()

    @abstractmethod
    def create_model(self, model_cfg=None):
        """Hook: build the trainable sub-modules.

        Args:
            model_cfg: Model section of the configuration. ``initialize``
                always passes it; the default only keeps argument-less calls
                working.
        """
        return

    @abstractmethod
    def create_load(self):
        """Hook kept for backward compatibility.

        ``initialize`` does not call this method; it calls ``load_model``.
        """
        return

    @abstractmethod
    def load_model(self):
        """Hook: restore model weights. Called by ``initialize``.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement load_model()'
        )

    @abstractmethod
    def model_dict(self):
        """Hook: return the model state to be saved."""
        return

    @abstractmethod
    def parameters_to_optimize(self):
        """Hook: return the parameters handed to the optimizer."""
        return

    @abstractmethod
    def encode(self, images, arcface_images):
        """Hook: encode the two image inputs."""
        return

    @abstractmethod
    def decode(self, codedict, epoch):
        """Hook: decode ``codedict`` for the given ``epoch``."""
        pass

    @abstractmethod
    def compute_losses(self, input, encoder_output, decoder_output):
        """Hook: compute the losses from the input and both outputs."""
        pass

    @abstractmethod
    def compute_masks(self, input, decoder_output):
        """Hook: compute the masks from the input and decoder output."""
        pass

    def setup_renderer(self, model_cfg):
        """Create the mesh renderer and reset the cached templates.

        Requires ``create_flame`` to have run, since the FLAME template is
        read here.

        Args:
            model_cfg: Model section of the configuration; its
                ``topology_path`` is passed to the renderer as
                ``obj_filename``.
        """
        self.render = MeshShapeRenderer(obj_filename=model_cfg.topology_path)
        self.verts_template_neutral = self.flame.v_template[None]
        # Filled lazily by ``create_template``.
        self.verts_template = None
        self.verts_template_uv = None

    def create_weights(self):
        """Fetch the per-vertex and per-triangle weights onto ``self.device``."""
        self.vertices_mask = self.masking.get_weights_per_vertex().to(self.device)
        self.triangle_mask = self.masking.get_weights_per_triangle().to(self.device)

    def create_template(self, B):
        """Build, or re-tile, the cached template vertices for batch size ``B``.

        On the first call the template is produced by running FLAME with
        all-zero shape and expression parameters and an all-zero pose except
        for index 3, which is set to 10 degrees expressed in radians. On later
        calls with a different batch size the first cached entry is repeated
        ``B`` times. Everything runs under ``torch.no_grad()``.

        Args:
            B: Batch size the cached templates should have.
        """
        with torch.no_grad():
            if self.verts_template is None:
                self.verts_template_neutral = self.flame.v_template[None]

                model_cfg = self.cfg.model
                pose = torch.zeros(B, model_cfg.n_pose, device=self.device)
                # 10 degrees in radians on pose component 3.
                # NOTE(review): presumably the jaw opening - confirm against
                # the FLAME pose layout.
                pose[:, 3] = 10.0 * np.pi / 180.0
                shape = torch.zeros(B, model_cfg.n_shape, device=self.device)
                expression = torch.zeros(B, model_cfg.n_exp, device=self.device)

                self.verts_template, _, _ = self.flame(
                    shape_params=shape,
                    expression_params=expression,
                    pose_params=pose,
                )

            # NOTE(review): after the first build the neutral template keeps a
            # batch axis of one and is only tiled to ``B`` once the batch size
            # changes. Left as is because callers may rely on it.
            if self.verts_template.shape[0] != B:
                self.verts_template_neutral = self.verts_template_neutral[0:1].repeat(B, 1, 1)
                self.verts_template = self.verts_template[0:1].repeat(B, 1, 1)
| |
|