File size: 3,460 Bytes
bd096d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: mica@tue.mpg.de


from abc import abstractmethod

import numpy as np
import torch
import torch.nn as nn

from configs.config import cfg
from models.flame import FLAME
from utils.masking import Masking


class BaseModel(nn.Module):
    """Shared scaffolding for models built around a FLAME instance.

    The constructor only stores configuration and builds the masking helper.
    Call :meth:`initialize` afterwards to create the FLAME layer, the
    subclass-specific network, the template placeholders and the mask weights.

    The hooks decorated with ``@abstractmethod`` are placeholders for
    subclasses. ``nn.Module`` does not use ``ABCMeta``, so the decorator does
    not prevent instantiation; the base implementations simply return ``None``.
    """

    def __init__(self, config=None, device=None, tag=''):
        """Store the configuration and create the masking helper.

        Args:
            config: Configuration object. When ``None`` the module-level
                ``cfg`` is used instead.
            device: Device later passed to ``.to(...)`` and tensor factories.
            tag: Label stored on the instance as ``self.tag``.
        """
        super().__init__()
        self.cfg = cfg if config is None else config

        self.tag = tag
        self.use_mask = self.cfg.train.use_mask
        self.device = device
        # Pass the resolved configuration: with config=None the original code
        # handed None to Masking even though self.cfg had fallen back to cfg.
        self.masking = Masking(self.cfg)
        self.testing = self.cfg.model.testing

    def initialize(self):
        """Build FLAME, the subclass model, template placeholders and mask weights.

        NOTE(review): this calls ``self.load_model()``, which is not defined in
        this class (the hook declared here is ``create_load``); subclasses
        presumably provide ``load_model`` -- confirm.
        """
        self.create_flame(self.cfg.model)
        self.create_model(self.cfg.model)
        self.load_model()
        self.setup_renderer(self.cfg.model)

        self.create_weights()

    def create_flame(self, model_cfg):
        """Instantiate FLAME on ``self.device`` and keep a copy of its template.

        Args:
            model_cfg: Model configuration forwarded to the ``FLAME`` constructor.
        """
        self.flame = FLAME(model_cfg).to(self.device)
        # Cloned so the stored tensor is independent of flame.v_template;
        # [None] adds a leading dimension of size 1.
        self.average_face = self.flame.v_template.clone()[None]

        self.flame.eval()

    @abstractmethod
    def create_model(self, model_cfg=None):
        """Hook for subclasses to build their network.

        Args:
            model_cfg: Model configuration; :meth:`initialize` passes
                ``self.cfg.model``. Defaults to ``None`` so that existing
                zero-argument calls keep working.
        """
        return

    @abstractmethod
    def create_load(self):
        """Hook for subclasses; the base implementation does nothing."""
        return

    @abstractmethod
    def model_dict(self):
        """Hook for subclasses; the base implementation does nothing."""
        return

    @abstractmethod
    def parameters_to_optimize(self):
        """Hook for subclasses; the base implementation does nothing."""
        return

    @abstractmethod
    def encode(self, images, arcface_images):
        """Hook for subclasses; the base implementation does nothing."""
        return

    @abstractmethod
    def decode(self, codedict, epoch):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    @abstractmethod
    def compute_losses(self, input, encoder_output, decoder_output):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    @abstractmethod
    def compute_masks(self, input, decoder_output):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def setup_renderer(self, model_cfg):
        """Reset the cached templates.

        Stores the FLAME template with a leading dimension of size 1 and
        clears ``verts_template`` / ``verts_template_uv`` so that
        :meth:`create_template` rebuilds the posed template on its next call.

        Args:
            model_cfg: Unused by the base implementation.
        """
        self.verts_template_neutral = self.flame.v_template[None]
        self.verts_template = None
        self.verts_template_uv = None

    def create_weights(self):
        """Fetch per-vertex and per-triangle weights from the masking helper."""
        self.vertices_mask = self.masking.get_weights_per_vertex().to(self.device)
        self.triangle_mask = self.masking.get_weights_per_triangle().to(self.device)

    def create_template(self, B):
        """Ensure the cached templates exist and have ``B`` entries along dim 0.

        On the first call (``self.verts_template is None``) FLAME is evaluated
        with zero shape and expression parameters and a pose that is zero
        except for index 3. On later calls with a different ``B`` the first
        entry of each cached template is repeated ``B`` times.

        Args:
            B: Requested size of the leading (batch) dimension.
        """
        with torch.no_grad():
            if self.verts_template is None:
                self.verts_template_neutral = self.flame.v_template[None]
                model_cfg = self.cfg.model
                shape = torch.zeros(B, model_cfg.n_shape, device=self.device)
                expression = torch.zeros(B, model_cfg.n_exp, device=self.device)
                pose = torch.zeros(B, model_cfg.n_pose, device=self.device)
                # 10 degrees expressed in radians (original inline note: "48").
                pose[:, 3] = 10.0 * np.pi / 180.0
                # use template mesh with open mouth
                self.verts_template, _, _ = self.flame(
                    shape_params=shape,
                    expression_params=expression,
                    pose_params=pose,
                )

            if self.verts_template.shape[0] != B:
                # Re-tile from the first entry so both templates match B.
                self.verts_template_neutral = self.verts_template_neutral[0:1].repeat(B, 1, 1)
                self.verts_template = self.verts_template[0:1].repeat(B, 1, 1)